pyspark Documentation
Release master
Dec 17, 2020
CONTENTS
1 Getting Started
  1.1 Installation
  1.2 Quickstart

2 User Guide
  2.1 Apache Arrow in PySpark
  2.2 Python Package Management

3 API Reference
  3.1 Spark SQL
  3.2 Structured Streaming
  3.3 ML
  3.4 Spark Streaming
  3.5 MLlib
  3.6 Spark Core
  3.7 Resource Management

4 Development
  4.1 Contributing to PySpark
  4.2 Testing PySpark
  4.3 Debugging PySpark
  4.4 Setting up IDEs

5 Migration Guide
  5.1 Upgrading from PySpark 2.4 to 3.0
  5.2 Upgrading from PySpark 2.3 to 2.4
  5.3 Upgrading from PySpark 2.3.0 to 2.3.1 and above
  5.4 Upgrading from PySpark 2.2 to 2.3
  5.5 Upgrading from PySpark 1.4 to 1.5
  5.6 Upgrading from PySpark 1.0-1.2 to 1.3

Bibliography

Index
Live Notebook | GitHub | Issues | Examples | Community
PySpark is an interface for Apache Spark in Python. It not only allows you to write Spark applications using Python APIs, but also provides the PySpark shell for interactively analyzing your data in a distributed environment. PySpark supports most of Spark's features such as Spark SQL, DataFrame, Streaming, MLlib (Machine Learning) and Spark Core.
Spark SQL and DataFrame
Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrame and can also act as a distributed SQL query engine.
Streaming
Running on top of Spark, the streaming feature in Apache Spark enables powerful interactive and analytical applications across both streaming and historical data, while inheriting Spark's ease of use and fault tolerance characteristics.
MLlib
Built on top of Spark, MLlib is a scalable machine learning library that provides a uniform set of high-level APIs that help users create and tune practical machine learning pipelines.
Spark Core
Spark Core is the underlying general execution engine for the Spark platform that all other functionality is built on top of. It provides an RDD (Resilient Distributed Dataset) and in-memory computing capabilities.
CHAPTER ONE: GETTING STARTED
This page summarizes the basic steps required to set up and get started with PySpark.
1.1 Installation
PySpark is included in the official releases of Spark available at the Apache Spark website. For Python users, PySpark also provides pip installation from PyPI. This is usually for local usage or as a client to connect to a cluster instead of setting up a cluster itself.
This page includes instructions for installing PySpark by using pip, Conda, downloading manually, and building from source.
1.1.1 Python Version Supported
Python 3.6 and above.
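A quick way to confirm the target interpreter meets this requirement (a minimal sketch):

import sys

# PySpark in this release supports Python 3.6 and above.
assert sys.version_info >= (3, 6), "PySpark requires Python 3.6+"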
1.1.2 Using PyPI
PySpark installation using PyPI is as follows:
pip install pyspark
If you want to install extra dependencies for a specific component, you can install it as below:
pip install pyspark[sql]
For PySpark with/without a specific Hadoop version, you can install it by using the HADOOP_VERSION environment variable as below:
HADOOP_VERSION=2.7 pip install pyspark
The default distribution uses Hadoop 3.2 and Hive 2.3. If users specify a different version of Hadoop, the pip installation automatically downloads that version and uses it in PySpark. Downloading it can take a while depending on the network and the mirror chosen. PYSPARK_RELEASE_MIRROR can be set to manually choose the mirror for faster downloading.

PYSPARK_RELEASE_MIRROR=http://mirror.apache-kr.org HADOOP_VERSION=2.7 pip install pyspark

It is recommended to use the -v option in pip to track the installation and download status.
3
pyspark Documentation, Release master
HADOOP_VERSION=2.7 pip install pyspark -v
Supported values in HADOOP_VERSION are:
• without: Spark pre-built with user-provided Apache Hadoop
• 2.7: Spark pre-built for Apache Hadoop 2.7
• 3.2: Spark pre-built for Apache Hadoop 3.2 and later (default)
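For example, to install the distribution that expects a user-provided Hadoop, the same pattern applies:

HADOOP_VERSION=without pip install pyspark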
Note that this way of installing PySpark with/without a specific Hadoop version is experimental. It can change or be removed between minor releases.
1.1.3 Using Conda
Conda is an open-source package management and environment management system which is a part of the Anaconda distribution. It is both cross-platform and language agnostic. In practice, Conda can replace both pip and virtualenv.
Create a new virtual environment from your terminal as shown below:
conda create -n pyspark_env
After the virtual environment is created, it should be visible under the list of Conda environments, which can be seen using the following command:
conda env list
Now activate the newly created environment with the following command:
conda activate pyspark_env
You can now install PySpark in the newly created environment as described in Using PyPI, for example as below. It will install PySpark under the new virtual environment pyspark_env created above.
pip install pyspark
Alternatively, you can install PySpark from Conda itself as below:
conda install pyspark
However, note that PySpark on Conda is not necessarily synced with the PySpark release cycle because it is maintained separately by the community.
1.1.4 Manually Downloading
PySpark is included in the distributions available at the Apache Spark website. You can download the distribution you want from the site. After that, uncompress the tar file into the directory where you want to install Spark, for example, as below:
tar xzvf spark-3.0.0-bin-hadoop2.7.tgz
Ensure the SPARK_HOME environment variable points to the directory where the tar file has been extracted. Update the PYTHONPATH environment variable such that it can find the PySpark and Py4J under SPARK_HOME/python/lib. One example of doing this is shown below:
cd spark-3.0.0-bin-hadoop2.7
export SPARK_HOME=`pwd`
export PYTHONPATH=$(ZIPS=("$SPARK_HOME"/python/lib/*.zip); IFS=:; echo "${ZIPS[*]}"):$PYTHONPATH
1.1.5 Installing from Source
To install PySpark from source, refer to Building Spark.
1.1.6 Dependencies
Package   Minimum supported version   Note
pandas    0.23.2                      Optional for SQL
NumPy     1.7                         Required for ML
pyarrow   1.0.0                       Optional for SQL
Py4J      0.10.9                      Required
Note that PySpark requires Java 8 or later with JAVA_HOME properly set. If using JDK 11, set -Dio.netty.tryReflectionSetAccessible=true for Arrow-related features and refer to Downloading.
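One way to pass that JVM flag is through the standard extraJavaOptions settings at submit time; a minimal sketch, where app.py is a placeholder application:

spark-submit \
  --conf "spark.driver.extraJavaOptions=-Dio.netty.tryReflectionSetAccessible=true" \
  --conf "spark.executor.extraJavaOptions=-Dio.netty.tryReflectionSetAccessible=true" \
  app.py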
1.2 Quickstart
This is a short introduction and quickstart for the PySpark DataFrame API. PySpark DataFrames are lazily evaluated. They are implemented on top of RDDs. When Spark transforms data, it does not immediately compute the transformation but plans how to compute later. When actions such as collect() are explicitly called, the computation starts. This notebook shows the basic usages of the DataFrame, geared mainly for new users. You can run the latest version of these examples by yourself on a live notebook here.
There is also other useful information in the Apache Spark documentation site, see the latest version of Spark SQL and DataFrames, RDD Programming Guide, Structured Streaming Programming Guide, Spark Streaming Programming Guide and Machine Learning Library (MLlib) Guide.
PySpark applications start with initializing SparkSession, which is the entry point of PySpark, as below. When running in the PySpark shell via the pyspark executable, the shell automatically creates the session in the variable spark for users.
[1]: from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
1.2.1 DataFrame Creation
A PySpark DataFrame can be created via pyspark.sql.SparkSession.createDataFrame, typically by passing a list of lists, tuples, dictionaries and pyspark.sql.Rows, a pandas DataFrame and an RDD consisting of such a list. pyspark.sql.SparkSession.createDataFrame takes the schema argument to specify the schema of the DataFrame. When it is omitted, PySpark infers the corresponding schema by taking a sample from the data.
Firstly, you can create a PySpark DataFrame from a list of rows.
[2]: from datetime import datetime, date
import pandas as pd
from pyspark.sql import Row

df = spark.createDataFrame([
    Row(a=1, b=2., c='string1', d=date(2000, 1, 1), e=datetime(2000, 1, 1, 12, 0)),
    Row(a=2, b=3., c='string2', d=date(2000, 2, 1), e=datetime(2000, 1, 2, 12, 0)),
    Row(a=3, b=4., c='string3', d=date(2000, 3, 1), e=datetime(2000, 1, 3, 12, 0))
])
df
[2]: DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]
Create a PySpark DataFrame with an explicit schema.
[3]: df = spark.createDataFrame([
    (1, 2., 'string1', date(2000, 1, 1), datetime(2000, 1, 1, 12, 0)),
    (2, 3., 'string2', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (3, 4., 'string3', date(2000, 3, 1), datetime(2000, 1, 3, 12, 0))
], schema='a long, b double, c string, d date, e timestamp')
df
[3]: DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]
Create a PySpark DataFrame from a pandas DataFrame.
[4]: pandas_df = pd.DataFrame({
    'a': [1, 2, 3],
    'b': [2., 3., 4.],
    'c': ['string1', 'string2', 'string3'],
    'd': [date(2000, 1, 1), date(2000, 2, 1), date(2000, 3, 1)],
    'e': [datetime(2000, 1, 1, 12, 0), datetime(2000, 1, 2, 12, 0), datetime(2000, 1, 3, 12, 0)]
})
df = spark.createDataFrame(pandas_df)
df
[4]: DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]
Create a PySpark DataFrame from an RDD consisting of a list of tuples.
[5]: rdd = spark.sparkContext.parallelize([
    (1, 2., 'string1', date(2000, 1, 1), datetime(2000, 1, 1, 12, 0)),
    (2, 3., 'string2', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
    (3, 4., 'string3', date(2000, 3, 1), datetime(2000, 1, 3, 12, 0))
])
df = spark.createDataFrame(rdd, schema=['a', 'b', 'c', 'd', 'e'])
df
[5]: DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]
The DataFrames created above all have the same results and schema.
[6]: # All DataFrames above result in the same output and schema.
df.show()
df.printSchema()
+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|
+---+---+-------+----------+-------------------+

root
 |-- a: long (nullable = true)
 |-- b: double (nullable = true)
 |-- c: string (nullable = true)
 |-- d: date (nullable = true)
 |-- e: timestamp (nullable = true)
1.2.2 Viewing Data
The top rows of a DataFrame can be displayed using DataFrame.show().
[7]: df.show(1)
+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
+---+---+-------+----------+-------------------+
only showing top 1 row
Alternatively, you can enable the spark.sql.repl.eagerEval.enabled configuration for eager evaluation of PySpark DataFrame in notebooks such as Jupyter. The number of rows to show can be controlled via the spark.sql.repl.eagerEval.maxNumRows configuration.
[8]: spark.conf.set('spark.sql.repl.eagerEval.enabled', True)
df
[8]: DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]
The rows can also be shown vertically. This is useful when rows are too long to show horizontally.
[9]: df.show(1, vertical=True)
-RECORD 0------------------
 a   | 1
 b   | 2.0
 c   | string1
 d   | 2000-01-01
 e   | 2000-01-01 12:00:00
only showing top 1 row
You can see the DataFrame’s schema and column names as follows:
[10]: df.columns
[10]: ['a', 'b', 'c', 'd', 'e']
[11]: df.printSchema()
root
 |-- a: long (nullable = true)
 |-- b: double (nullable = true)
 |-- c: string (nullable = true)
 |-- d: date (nullable = true)
 |-- e: timestamp (nullable = true)
Show the summary of the DataFrame:
[12]: df.select("a", "b", "c").describe().show()
+-------+---+---+-------+
|summary|  a|  b|      c|
+-------+---+---+-------+
|  count|  3|  3|      3|
|   mean|2.0|3.0|   null|
| stddev|1.0|1.0|   null|
|    min|  1|2.0|string1|
|    max|  3|4.0|string3|
+-------+---+---+-------+
DataFrame.collect() collects the distributed data to the driver side as local data in Python. Note that this can throw an out-of-memory error when the dataset is too large to fit in the driver side, because it collects all the data from executors to the driver side.
[13]: df.collect()
[13]: [Row(a=1, b=2.0, c='string1', d=datetime.date(2000, 1, 1), e=datetime.datetime(2000, 1, 1, 12, 0)),
 Row(a=2, b=3.0, c='string2', d=datetime.date(2000, 2, 1), e=datetime.datetime(2000, 1, 2, 12, 0)),
 Row(a=3, b=4.0, c='string3', d=datetime.date(2000, 3, 1), e=datetime.datetime(2000, 1, 3, 12, 0))]
In order to avoid throwing an out-of-memory exception, use DataFrame.take() or DataFrame.tail().
[14]: df.take(1)
[14]: [Row(a=1, b=2.0, c='string1', d=datetime.date(2000, 1, 1), e=datetime.datetime(2000, 1, 1, 12, 0))]
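DataFrame.tail() collects from the end of the DataFrame instead; with the example data above it should return something like the following (a sketch):

df.tail(1)
# [Row(a=3, b=4.0, c='string3', d=datetime.date(2000, 3, 1), e=datetime.datetime(2000, 1, 3, 12, 0))]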
PySpark DataFrame also provides the conversion back to a pandas DataFrame to leverage pandas APIs. Note that toPandas also collects all data into the driver side, which can easily cause an out-of-memory error when the data is too large to fit into the driver side.
[15]: df.toPandas()
[15]:    a    b        c           d                   e
0  1  2.0  string1  2000-01-01 2000-01-01 12:00:00
1  2  3.0  string2  2000-02-01 2000-01-02 12:00:00
2  3  4.0  string3  2000-03-01 2000-01-03 12:00:00
1.2.3 Selecting and Accessing Data
PySpark DataFrame is lazily evaluated and simply selecting a column does not trigger the computation but it returns a Column instance.
[16]: df.a
[16]: Column<b'a'>
In fact, most column-wise operations return Columns.
[17]: from pyspark.sql import Column
from pyspark.sql.functions import upper
type(df.c) == type(upper(df.c)) == type(df.c.isNull())
[17]: True
These Columns can be used to select columns from a DataFrame. For example, DataFrame.select() takes Column instances and returns another DataFrame.
[18]: df.select(df.c).show()
+-------+
|      c|
+-------+
|string1|
|string2|
|string3|
+-------+
Assign a new Column instance.
[19]: df.withColumn('upper_c', upper(df.c)).show()
+---+---+-------+----------+-------------------+-------+
|  a|  b|      c|         d|                  e|upper_c|
+---+---+-------+----------+-------------------+-------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|STRING1|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|STRING2|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|STRING3|
+---+---+-------+----------+-------------------+-------+
To select a subset of rows, use DataFrame.filter().
[20]: df.filter(df.a == 1).show()
+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
+---+---+-------+----------+-------------------+
1.2.4 Applying a Function
PySpark supports various UDFs and APIs to allow users to execute Python native functions. See also the latest Pandas UDFs and Pandas Function APIs. For instance, the example below allows users to directly use the APIs in a pandas Series within a Python native function.
[21]: import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf('long')
def pandas_plus_one(series: pd.Series) -> pd.Series:
    # Simply plus one by using pandas Series.
    return series + 1

df.select(pandas_plus_one(df.a)).show()
+------------------+
|pandas_plus_one(a)|
+------------------+
|                 2|
|                 3|
|                 4|
+------------------+
Another example is DataFrame.mapInPandas, which allows users to directly use the APIs in a pandas DataFrame without any restrictions such as on the result length.
[22]: def pandas_filter_func(iterator):
    for pandas_df in iterator:
        yield pandas_df[pandas_df.a == 1]

df.mapInPandas(pandas_filter_func, schema=df.schema).show()
+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
+---+---+-------+----------+-------------------+
1.2.5 Grouping Data
PySpark DataFrame also provides a way of handling grouped data by using the common split-apply-combine strategy: it groups the data by a certain condition, applies a function to each group, and then combines the results back into a DataFrame.
[23]: df = spark.createDataFrame([
    ['red', 'banana', 1, 10], ['blue', 'banana', 2, 20], ['red', 'carrot', 3, 30],
    ['blue', 'grape', 4, 40], ['red', 'carrot', 5, 50], ['black', 'carrot', 6, 60],
    ['red', 'banana', 7, 70], ['red', 'grape', 8, 80]], schema=['color', 'fruit', 'v1', 'v2'])
df.show()
+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana|  1| 10|
| blue|banana|  2| 20|
|  red|carrot|  3| 30|
| blue| grape|  4| 40|
|  red|carrot|  5| 50|
|black|carrot|  6| 60|
|  red|banana|  7| 70|
|  red| grape|  8| 80|
+-----+------+---+---+
Grouping and then applying the avg() function to the resulting groups.
[24]: df.groupby('color').avg().show()
+-----+-------+-------+
|color|avg(v1)|avg(v2)|
+-----+-------+-------+
|  red|    4.8|   48.0|
|black|    6.0|   60.0|
| blue|    3.0|   30.0|
+-----+-------+-------+
You can also apply a Python native function against each group by using pandas APIs.
[25]: def plus_mean(pandas_df):
    return pandas_df.assign(v1=pandas_df.v1 - pandas_df.v1.mean())

df.groupby('color').applyInPandas(plus_mean, schema=df.schema).show()
+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana| -3| 10|
|  red|carrot| -1| 30|
|  red|carrot|  0| 50|
|  red|banana|  2| 70|
|  red| grape|  3| 80|
|black|carrot|  0| 60|
| blue|banana| -1| 20|
| blue| grape|  1| 40|
+-----+------+---+---+
Co-grouping and applying a function.
[26]: df1 = spark.createDataFrame(
    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
    ('time', 'id', 'v1'))
df2 = spark.createDataFrame(
    [(20000101, 1, 'x'), (20000101, 2, 'y')],
    ('time', 'id', 'v2'))

def asof_join(l, r):
    return pd.merge_asof(l, r, on='time', by='id')

df1.groupby('id').cogroup(df2.groupby('id')).applyInPandas(
    asof_join, schema='time int, id int, v1 double, v2 string').show()
+--------+---+---+---+
|    time| id| v1| v2|
+--------+---+---+---+
|20000101|  1|1.0|  x|
|20000102|  1|3.0|  x|
|20000101|  2|2.0|  y|
|20000102|  2|4.0|  y|
+--------+---+---+---+
1.2.6 Getting Data in/out
CSV is straightforward and easy to use. Parquet and ORC are efficient and compact file formats that are faster to read and write.
There are many other data sources available in PySpark such as JDBC, text, binaryFile, Avro, etc. See also the latest Spark SQL, DataFrames and Datasets Guide in the Apache Spark documentation.
CSV
[27]: df.write.csv('foo.csv', header=True)
spark.read.csv('foo.csv', header=True).show()
+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana|  1| 10|
| blue|banana|  2| 20|
|  red|carrot|  3| 30|
| blue| grape|  4| 40|
|  red|carrot|  5| 50|
|black|carrot|  6| 60|
|  red|banana|  7| 70|
|  red| grape|  8| 80|
+-----+------+---+---+
Parquet
[28]: df.write.parquet('bar.parquet')
spark.read.parquet('bar.parquet').show()
+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana|  1| 10|
| blue|banana|  2| 20|
|  red|carrot|  3| 30|
| blue| grape|  4| 40|
|  red|carrot|  5| 50|
|black|carrot|  6| 60|
|  red|banana|  7| 70|
|  red| grape|  8| 80|
+-----+------+---+---+
ORC
[29]: df.write.orc('zoo.orc')
spark.read.orc('zoo.orc').show()
+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana|  1| 10|
| blue|banana|  2| 20|
|  red|carrot|  3| 30|
| blue| grape|  4| 40|
|  red|carrot|  5| 50|
|black|carrot|  6| 60|
|  red|banana|  7| 70|
|  red| grape|  8| 80|
+-----+------+---+---+
1.2.7 Working with SQL
DataFrame and Spark SQL share the same execution engine, so they can be used interchangeably and seamlessly. For example, you can register the DataFrame as a table and run SQL on it easily as below:
[30]: df.createOrReplaceTempView("tableA")
spark.sql("SELECT count(*) from tableA").show()
+--------+
|count(1)|
+--------+
|       8|
+--------+
In addition, UDFs can be registered and invoked in SQL out of the box:
[31]: @pandas_udf("integer")
def add_one(s: pd.Series) -> pd.Series:
    return s + 1

spark.udf.register("add_one", add_one)
spark.sql("SELECT add_one(v1) FROM tableA").show()
+-----------+
|add_one(v1)|
+-----------+
|          2|
|          3|
|          4|
|          5|
|          6|
|          7|
|          8|
|          9|
+-----------+
These SQL expressions can directly be mixed and used as PySpark columns.
[32]: from pyspark.sql.functions import expr
df.selectExpr('add_one(v1)').show()
df.select(expr('count(*)') > 0).show()
+-----------+
|add_one(v1)|
+-----------+
|          2|
|          3|
|          4|
|          5|
|          6|
|          7|
|          8|
|          9|
+-----------+

+--------------+
|(count(1) > 0)|
+--------------+
|          true|
+--------------+
CHAPTER TWO: USER GUIDE
2.1 Apache Arrow in PySpark
Apache Arrow is an in-memory columnar data format that is used in Spark to efficiently transfer data between JVM and Python processes. This currently is most beneficial to Python users that work with Pandas/NumPy data. Its usage is not automatic and might require some minor changes to configuration or code to take full advantage and ensure compatibility. This guide will give a high-level description of how to use Arrow in Spark and highlight any differences when working with Arrow-enabled data.
2.1.1 Ensure PyArrow Installed
To use Apache Arrow in PySpark, the recommended version of PyArrow should be installed. If you install PySpark using pip, then PyArrow can be brought in as an extra dependency of the SQL module with the command pip install pyspark[sql]. Otherwise, you must ensure that PyArrow is installed and available on all cluster nodes. You can install it using pip or conda from the conda-forge channel. See PyArrow installation for details.
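A quick way to verify that a compatible PyArrow is importable on a node (a minimal sketch; the 1.0.0 minimum comes from the dependency table in the installation section):

import pyarrow

# pyspark.sql requires PyArrow 1.0.0 or above.
print(pyarrow.__version__)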
2.1.2 Enabling for Conversion to/from Pandas
Arrow is available as an optimization when converting a Spark DataFrame to a Pandas DataFrame using the call DataFrame.toPandas() and when creating a Spark DataFrame from a Pandas DataFrame with SparkSession.createDataFrame(). To use Arrow when executing these calls, users need to first set the Spark configuration spark.sql.execution.arrow.pyspark.enabled to true. This is disabled by default.
In addition, optimizations enabled by spark.sql.execution.arrow.pyspark.enabled could fall back automatically to a non-Arrow implementation if an error occurs before the actual computation within Spark. This can be controlled by spark.sql.execution.arrow.pyspark.fallback.enabled.
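The fallback behavior is an ordinary runtime configuration as well; for example, to fail fast instead of silently falling back when Arrow cannot be used (a sketch):

# Disable the automatic fallback to the non-Arrow code path.
spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")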
import numpy as np
import pandas as pd

# Enable Arrow-based columnar data transfers
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Generate a Pandas DataFrame
pdf = pd.DataFrame(np.random.rand(100, 3))

# Create a Spark DataFrame from a Pandas DataFrame using Arrow
df = spark.createDataFrame(pdf)

# Convert the Spark DataFrame back to a Pandas DataFrame using Arrow
result_pdf = df.select("*").toPandas()
Using the above optimizations with Arrow will produce the same results as when Arrow is not enabled.
Note that even with Arrow, DataFrame.toPandas() results in the collection of all records in the DataFrame to the driver program and should be done on a small subset of the data. Not all Spark data types are currently supported and an error can be raised if a column has an unsupported type. If an error occurs during SparkSession.createDataFrame(), Spark will fall back to create the DataFrame without Arrow.
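Because of this, it is common to cap the result on the cluster side before converting; a minimal sketch:

# Collect at most 1000 rows to the driver as a pandas DataFrame.
small_pdf = df.limit(1000).toPandas()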
2.1.3 Pandas UDFs (a.k.a. Vectorized UDFs)
Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and Pandas to work with the data, which allows vectorized operations. A Pandas UDF is defined using pandas_udf() as a decorator or to wrap the function, and no additional configuration is required. A Pandas UDF behaves as a regular PySpark function API in general.
Before Spark 3.0, Pandas UDFs used to be defined with pyspark.sql.functions.PandasUDFType. From Spark 3.0 with Python 3.6+, you can also use Python type hints. Using Python type hints is preferred, and pyspark.sql.functions.PandasUDFType will be deprecated in a future release.
Note that the type hint should use pandas.Series in all cases, but there is one variant: pandas.DataFrame should be used for the input or output type hint instead when the input or output column is of StructType. The following example shows a Pandas UDF which takes a long column, a string column and a struct column, and outputs a struct column. It requires the function to specify the type hints of pandas.Series and pandas.DataFrame as below:
import pandas as pd
from pyspark.sql.functions import pandas_udf
@pandas_udf("col1 string, col2 long")
def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame:
    s3['col2'] = s1 + s2.str.len()
    return s3

# Create a Spark DataFrame that has three columns including a struct column.
df = spark.createDataFrame(
    [[1, "a string", ("a nested string",)]],
    "long_col long, string_col string, struct_col struct<col1:string>")

df.printSchema()
# root
# |-- long_col: long (nullable = true)
# |-- string_col: string (nullable = true)
# |-- struct_col: struct (nullable = true)
# |    |-- col1: string (nullable = true)

df.select(func("long_col", "string_col", "struct_col")).printSchema()
# |-- func(long_col, string_col, struct_col): struct (nullable = true)
# |    |-- col1: string (nullable = true)
# |    |-- col2: long (nullable = true)
The following sections describe the combinations of supported type hints. For simplicity, the pandas.DataFrame variant is omitted.
Series to Series
The type hint can be expressed as pandas.Series, ... -> pandas.Series.
By using pandas_udf() with the function having such type hints above, it creates a Pandas UDF where the given function takes one or more pandas.Series and outputs one pandas.Series. The output of the function should always be of the same length as the input. Internally, PySpark will execute a Pandas UDF by splitting columns into batches and calling the function for each batch as a subset of the data, then concatenating the results together.
The following example shows how to create this Pandas UDF that computes the product of 2 columns.
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

# Declare the function and create the UDF
def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series:
    return a * b

multiply = pandas_udf(multiply_func, returnType=LongType())

# The function for a pandas_udf should be able to execute with local Pandas data
x = pd.Series([1, 2, 3])
print(multiply_func(x, x))
# 0    1
# 1    4
# 2    9
# dtype: int64

# Create a Spark DataFrame, 'spark' is an existing SparkSession
df = spark.createDataFrame(pd.DataFrame(x, columns=["x"]))

# Execute function as a Spark vectorized UDF
df.select(multiply(col("x"), col("x"))).show()
# +-------------------+
# |multiply_func(x, x)|
# +-------------------+
# |                  1|
# |                  4|
# |                  9|
# +-------------------+
For detailed usage, please see pandas_udf().
Iterator of Series to Iterator of Series
The type hint can be expressed as Iterator[pandas.Series] -> Iterator[pandas.Series].
By using pandas_udf() with the function having such type hints above, it creates a Pandas UDF where the given function takes an iterator of pandas.Series and outputs an iterator of pandas.Series. The length of the entire output from the function should be the same as the length of the entire input; therefore, it can prefetch the data from the input iterator as long as the lengths are the same. In this case, the created Pandas UDF requires one input column when the Pandas UDF is called. To use multiple input columns, a different type hint is required. See Iterator of Multiple Series to Iterator of Series.
It is also useful when the UDF execution requires initializing some state, although internally it works identically to the Series to Series case. The pseudocode below illustrates the example.
@pandas_udf("long")
def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Do some expensive initialization with a state
    state = very_expensive_initialization()
    for x in iterator:
        # Use that state for the whole iterator.
        yield calculate_with_state(x, state)

df.select(calculate("value")).show()
The following example shows how to create this Pandas UDF:
from typing import Iterator
import pandas as pd
from pyspark.sql.functions import pandas_udf
pdf = pd.DataFrame([1, 2, 3], columns=["x"])
df = spark.createDataFrame(pdf)

# Declare the function and create the UDF
@pandas_udf("long")
def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    for x in iterator:
        yield x + 1

df.select(plus_one("x")).show()
# +-----------+
# |plus_one(x)|
# +-----------+
# |          2|
# |          3|
# |          4|
# +-----------+
For detailed usage, please see pandas_udf().
Iterator of Multiple Series to Iterator of Series
The type hint can be expressed as Iterator[Tuple[pandas.Series, ...]] -> Iterator[pandas.Series].
By using pandas_udf() with the function having such type hints above, it creates a Pandas UDF where the given function takes an iterator of a tuple of multiple pandas.Series and outputs an iterator of pandas.Series. In this case, the created Pandas UDF requires as many input columns as there are series in the tuple when the Pandas UDF is called. Otherwise, it has the same characteristics and restrictions as the Iterator of Series to Iterator of Series case.
The following example shows how to create this Pandas UDF:
from typing import Iterator, Tuple
import pandas as pd
from pyspark.sql.functions import pandas_udf
pdf = pd.DataFrame([1, 2, 3], columns=["x"])
df = spark.createDataFrame(pdf)

# Declare the function and create the UDF
@pandas_udf("long")
def multiply_two_cols(
        iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    for a, b in iterator:
        yield a * b

df.select(multiply_two_cols("x", "x")).show()
# +-----------------------+
# |multiply_two_cols(x, x)|
# +-----------------------+
# |                      1|
# |                      4|
# |                      9|
# +-----------------------+
For detailed usage, please see pandas_udf().
Series to Scalar
The type hint can be expressed as pandas.Series, ... -> Any.
By using pandas_udf() with the function having such type hints above, it creates a Pandas UDF similar to PySpark's aggregate functions. The given function takes pandas.Series and returns a scalar value. The return type should be a primitive data type, and the returned scalar can be either a Python primitive type, e.g., int or float, or a NumPy data type, e.g., numpy.int64 or numpy.float64. Any should ideally be a specific scalar type accordingly.
This UDF can also be used with GroupedData.agg() and Window. It defines an aggregation from one or more pandas.Series to a scalar value, where each pandas.Series represents a column within the group or window.
Note that this type of UDF does not support partial aggregation, and all data for a group or window will be loaded into memory. Also, only unbounded windows are supported with grouped aggregate Pandas UDFs currently. The following example shows how to use this type of UDF to compute the mean with group-by and window operations:
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql import Window
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))
# Declare the function and create the UDF
@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    return v.mean()

df.select(mean_udf(df['v'])).show()
# +-----------+
# |mean_udf(v)|
# +-----------+
# |        4.2|
# +-----------+

df.groupby("id").agg(mean_udf(df['v'])).show()
# +---+-----------+
# | id|mean_udf(v)|
# +---+-----------+
# |  1|        1.5|
# |  2|        6.0|
# +---+-----------+
w = Window \
    .partitionBy('id') \
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()
# +---+----+------+
# | id|   v|mean_v|
# +---+----+------+
# |  1| 1.0|   1.5|
# |  1| 2.0|   1.5|
# |  2| 3.0|   6.0|
# |  2| 5.0|   6.0|
# |  2|10.0|   6.0|
# +---+----+------+
For detailed usage, please see pandas_udf().
2.1.4 Pandas Function APIs
Pandas Function APIs can directly apply a Python native function against the whole DataFrame by using Pandas instances. Internally it works similarly to Pandas UDFs, using Arrow to transfer data and Pandas to work with the data, which allows vectorized operations. However, a Pandas Function API behaves as a regular API under PySpark DataFrame instead of Column, and Python type hints in Pandas Function APIs are optional and do not affect how it works internally at this moment, although they might be required in the future.
From Spark 3.0, the grouped map pandas UDF is now categorized as a separate Pandas Function API, DataFrame.groupby().applyInPandas(). It is still possible to use it with pyspark.sql.functions.PandasUDFType and DataFrame.groupby().apply() as it was; however, it is preferred to use DataFrame.groupby().applyInPandas() directly. Using pyspark.sql.functions.PandasUDFType will be deprecated in the future.
Grouped Map
Grouped map operations with Pandas instances are supported by DataFrame.groupby().applyInPandas(), which requires a Python function that takes a pandas.DataFrame and returns another pandas.DataFrame. It maps each group to a pandas.DataFrame in the Python function.
This API implements the “split-apply-combine” pattern which consists of three steps:
• Split the data into groups by using DataFrame.groupBy().
• Apply a function on each group. The input and output of the function are both pandas.DataFrame. The input data contains all the rows and columns for each group.
• Combine the results into a new PySpark DataFrame.
To use DataFrame.groupBy().applyInPandas(), the user needs to define the following:
• A Python function that defines the computation for each group.
• A StructType object or a string that defines the schema of the output PySpark DataFrame.
The column labels of the returned pandas.DataFrame must either match the field names in the defined output schema if specified as strings, or match the field data types by position if not strings, e.g. integer indices. See pandas.DataFrame on how to label columns when constructing a pandas.DataFrame.
Note that all data for a group will be loaded into memory before the function is applied. This can lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for maxRecordsPerBatch is not applied on groups and it is up to the user to ensure that the grouped data will fit into the available memory.
The following example shows how to use DataFrame.groupby().applyInPandas() to subtract the mean from each value in the group.
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))
def subtract_mean(pdf):
    # pdf is a pandas.DataFrame
    v = pdf.v
    return pdf.assign(v=v - v.mean())
df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double").show()
# +---+----+
# | id|   v|
# +---+----+
# |  1|-0.5|
# |  1| 0.5|
# |  2|-3.0|
# |  2|-1.0|
# |  2| 4.0|
# +---+----+
For detailed usage, please see GroupedData.applyInPandas().
Map
Map operations with Pandas instances are supported by DataFrame.mapInPandas(), which maps an iterator of pandas.DataFrames to another iterator of pandas.DataFrames that represents the current PySpark DataFrame, and returns the result as a PySpark DataFrame. The function takes and outputs an iterator of pandas.DataFrame. It can return output of arbitrary length, in contrast to some Pandas UDFs, although internally it works similarly to the Series to Series Pandas UDF.
The following example shows how to use DataFrame.mapInPandas():
df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
def filter_func(iterator):
    for pdf in iterator:
        yield pdf[pdf.id == 1]

df.mapInPandas(filter_func, schema=df.schema).show()
# +---+---+
# | id|age|
# +---+---+
# |  1| 21|
# +---+---+
For detailed usage, please see DataFrame.mapInPandas().
Co-grouped Map
Co-grouped map operations with Pandas instances are supported by DataFrame.groupby().cogroup().applyInPandas(), which allows two PySpark DataFrames to be cogrouped by a common key and then a Python function applied to each cogroup. It consists of the following steps:
• Shuffle the data such that the groups of each dataframe which share a key are cogrouped together.
• Apply a function to each cogroup. The input of the function is two pandas.DataFrame (with an optional tuple representing the key). The output of the function is a pandas.DataFrame.
• Combine the pandas.DataFrames from all groups into a new PySpark DataFrame.
To use groupBy().cogroup().applyInPandas(), the user needs to define the following:
• A Python function that defines the computation for each cogroup.
• A StructType object or a string that defines the schema of the output PySpark DataFrame.
The column labels of the returned pandas.DataFrame must either match the field names in the defined output schema if specified as strings, or match the field data types by position if not strings, e.g. integer indices. See pandas.DataFrame on how to label columns when constructing a pandas.DataFrame.
Note that all data for a cogroup will be loaded into memory before the function is applied. This can lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for maxRecordsPerBatch is not applied and it is up to the user to ensure that the cogrouped data will fit into the available memory.
The following example shows how to use DataFrame.groupby().cogroup().applyInPandas() to perform an asof join between two datasets.
import pandas as pd
df1 = spark.createDataFrame(
    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
    ("time", "id", "v1"))

df2 = spark.createDataFrame(
    [(20000101, 1, "x"), (20000101, 2, "y")],
    ("time", "id", "v2"))
def asof_join(l, r):
    return pd.merge_asof(l, r, on="time", by="id")
df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
    asof_join, schema="time int, id int, v1 double, v2 string").show()
# +--------+---+---+---+
# |    time| id| v1| v2|
# +--------+---+---+---+
# |20000101|  1|1.0|  x|
# |20000102|  1|3.0|  x|
# |20000101|  2|2.0|  y|
# |20000102|  2|4.0|  y|
# +--------+---+---+---+
For detailed usage, please see PandasCogroupedOps.applyInPandas().
2.1.5 Usage Notes
Supported SQL Types
Currently, all Spark SQL data types are supported by Arrow-based conversion except ArrayType of TimestampType and nested StructType. MapType is only supported when using PyArrow 2.0.0 and above.
Setting Arrow Batch Size
Data partitions in Spark are converted into Arrow record batches, which can temporarily lead to high memory usage in the JVM. To avoid possible out of memory exceptions, the size of the Arrow record batches can be adjusted by setting the conf spark.sql.execution.arrow.maxRecordsPerBatch to an integer that will determine the maximum number of rows for each batch. The default value is 10,000 records per batch. If the number of columns is large, the value should be adjusted accordingly. Using this limit, each data partition will be made into 1 or more record batches for processing.
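For example, to lower the batch size for a wide DataFrame (a sketch; 5000 is an arbitrary illustrative value):

spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "5000")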
Timestamp with Time Zone Semantics
Spark internally stores timestamps as UTC values, and timestamp data that is brought in without a specified time zone is converted as local time to UTC with microsecond resolution. When timestamp data is exported or displayed in Spark, the session time zone is used to localize the timestamp values. The session time zone is set with the configuration spark.sql.session.timeZone and will default to the JVM system local time zone if not set. Pandas uses a datetime64 type with nanosecond resolution, datetime64[ns], with optional time zone on a per-column basis.
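The session time zone can be pinned explicitly to make these conversions predictable; for example:

# Use UTC for the session instead of the JVM system default.
spark.conf.set("spark.sql.session.timeZone", "UTC")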
When timestamp data is transferred from Spark to Pandas it will be converted to nanoseconds, and each column will be converted to the Spark session time zone then localized to that time zone, which removes the time zone and displays values as local time. This will occur when calling DataFrame.toPandas() or pandas_udf with timestamp columns.
When timestamp data is transferred from Pandas to Spark, it will be converted to UTC microseconds. This occurs when calling SparkSession.createDataFrame() with a Pandas DataFrame or when returning a timestamp from a pandas_udf. These conversions are done automatically to ensure Spark will have data in the expected format, so it is not necessary to do any of these conversions yourself. Any nanosecond values will be truncated.
Note that a standard UDF (non-Pandas) will load timestamp data as Python datetime objects, which is different than a Pandas timestamp. It is recommended to use Pandas time series functionality when working with timestamps in pandas_udfs to get the best performance; see here for details.
Recommended Pandas and PyArrow Versions
For usage with pyspark.sql, the minimum supported version of pandas is 0.23.2 and the minimum supported version of PyArrow is 1.0.0. Higher versions may be used; however, compatibility and data correctness cannot be guaranteed and should be verified by the user.
Compatibility Setting for PyArrow >= 0.15.0 and Spark 2.3.x, 2.4.x
Since Arrow 0.15.0, a change in the binary IPC format requires an environment variable to be compatible with previous versions of Arrow <= 0.14.1. This is only necessary for PySpark users with versions 2.3.x and 2.4.x that have manually upgraded PyArrow to 0.15.0. The following can be added to conf/spark-env.sh to use the legacy Arrow IPC format:
ARROW_PRE_0_15_IPC_FORMAT=1
This will instruct PyArrow >= 0.15.0 to use the legacy IPC format with the older Arrow Java that is in Spark 2.3.x and 2.4.x. Not setting this environment variable will lead to a similar error as described in SPARK-29367 when running pandas_udfs or DataFrame.toPandas() with Arrow enabled. More information about the Arrow IPC change can be read on the Arrow 0.15.0 release blog.
2.2 Python Package Management
When you want to run your PySpark application on a cluster such as YARN, Kubernetes, Mesos, etc., you need to make sure that your code and all used libraries are available on the executors.
As an example, let's say you want to run the Pandas UDF examples. As they use pyarrow as an underlying implementation, we need to make sure pyarrow is installed on each executor in the cluster. Otherwise you may get errors such as ModuleNotFoundError: No module named 'pyarrow'.
Here is the script app.py from the previous example that will be executed on the cluster:
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql import SparkSession

def main(spark):
    df = spark.createDataFrame(
        [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
        ("id", "v"))

    @pandas_udf("double")
    def mean_udf(v: pd.Series) -> float:
        return v.mean()

    print(df.groupby("id").agg(mean_udf(df['v'])).collect())

if __name__ == "__main__":
    main(SparkSession.builder.getOrCreate())
There are multiple ways to manage Python dependencies in the cluster:
• Using PySpark Native Features
• Using Conda
• Using Virtualenv
• Using PEX
2.2.1 Using PySpark Native Features
PySpark allows you to upload Python files (.py), zipped Python packages (.zip), and Egg files (.egg) to the executors by:
• Setting the configuration spark.submit.pyFiles
• Using the --py-files option in Spark scripts
• Directly calling pyspark.SparkContext.addPyFile() in applications
This is a straightforward method to ship additional custom Python code to the cluster. You can just add individual files or zip whole packages and upload them. Using pyspark.SparkContext.addPyFile() allows you to upload code even after your job has started.
However, it does not allow adding packages built as Wheels, and therefore does not let you include dependencies with native code.
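As an illustration of the last option above, a script can attach a zipped package at runtime (a minimal sketch; deps.zip is a hypothetical archive of your package):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Distribute the (hypothetical) zipped package to every executor.
spark.sparkContext.addPyFile("deps.zip")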
2.2.2 Using Conda
Conda is one of the most widely-used Python package management systems. PySpark users can directly use a Conda environment to ship their third-party Python packages by leveraging conda-pack, which is a command line tool creating relocatable Conda environments.
The example below creates a Conda environment to use on both the driver and executor and packs it into an archive file. This archive file captures the Conda environment for Python and stores both the Python interpreter and all its relevant dependencies.
conda create -y -n pyspark_conda_env -c conda-forge pyarrow pandas conda-pack
conda activate pyspark_conda_env
conda pack -f -o pyspark_conda_env.tar.gz
After that, you can ship it together with scripts or in the code by using the --archives option or spark.archives configuration (spark.yarn.dist.archives in Yarn). It automatically unpacks the archive on executors.
In the case of a spark-submit script, you can use it as follows:
export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./environment/bin/python
spark-submit --archives pyspark_conda_env.tar.gz#environment app.py
Note that PYSPARK_DRIVER_PYTHON above is not required for cluster modes in Yarn or Kubernetes.
If you’re on a regular Python shell or notebook, you can try it as shown below:
import os
from pyspark.sql import SparkSession
from app import main

os.environ['PYSPARK_PYTHON'] = "./environment/bin/python"
spark = SparkSession.builder.config(
    "spark.archives",  # 'spark.yarn.dist.archives' in Yarn.
    "pyspark_conda_env.tar.gz#environment").getOrCreate()
main(spark)
For a pyspark shell:
export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./environment/bin/python
pyspark --archives pyspark_conda_env.tar.gz#environment
2.2.3 Using Virtualenv
Virtualenv is a Python tool to create isolated Python environments. Since Python 3.3, a subset of its features has been integrated into Python as a standard library under the venv module. PySpark users can use virtualenv to manage Python dependencies in their clusters by using venv-pack in a similar way as conda-pack.
A virtual environment to use on both driver and executor can be created as demonstrated below. It packs the current virtual environment into an archive file that contains both the Python interpreter and the dependencies.
python -m venv pyspark_venv
source pyspark_venv/bin/activate
pip install pyarrow pandas venv-pack
venv-pack -o pyspark_venv.tar.gz
You can directly pass/unpack the archive file and enable the environment on executors by leveraging the --archives option or spark.archives configuration (spark.yarn.dist.archives in Yarn).
For spark-submit, you can use it by running the command as follows. Also, notice that PYSPARK_DRIVER_PYTHON is not necessary in Kubernetes or Yarn cluster modes.
export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./environment/bin/python
spark-submit --archives pyspark_venv.tar.gz#environment app.py
For regular Python shells or notebooks:
import os
from pyspark.sql import SparkSession
from app import main

os.environ['PYSPARK_PYTHON'] = "./environment/bin/python"
spark = SparkSession.builder.config(
    "spark.archives",  # 'spark.yarn.dist.archives' in Yarn.
    "pyspark_venv.tar.gz#environment").getOrCreate()
main(spark)
In the case of a pyspark shell:
export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./environment/bin/python
pyspark --archives pyspark_venv.tar.gz#environment
2.2.4 Using PEX
PySpark can also use PEX to ship the Python packages together. PEX is a tool that creates a self-contained Python environment. This is similar to Conda or virtualenv, but a .pex file is executable by itself.
The following example creates a .pex file for the driver and executor to use. The file contains the Python dependencies specified with the pex command.
pip install pyarrow pandas pex
pex pyspark pyarrow pandas -o pyspark_pex_env.pex
This file behaves similarly to a regular Python interpreter.
./pyspark_pex_env.pex -c "import pandas; print(pandas.__version__)"
1.1.5
However, a .pex file does not include a Python interpreter itself under the hood, so all nodes in a cluster should have the same Python interpreter installed.
In order to transfer and use the .pex file in a cluster, you should ship it via the spark.files configuration (spark.yarn.dist.files in Yarn) or the --files option, because these are regular files instead of directories or archive files.
For application submission, you run the commands as shown below. Note that PYSPARK_DRIVER_PYTHON is not needed for cluster modes in Yarn or Kubernetes. You may also need to set the PYSPARK_PYTHON environment variable on the AppMaster, e.g. --conf spark.yarn.appMasterEnv.PYSPARK_PYTHON=./myarchive.pex in Yarn cluster mode.
export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./pyspark_pex_env.pex
spark-submit --files pyspark_pex_env.pex app.py
For regular Python shells or notebooks:
import os
from pyspark.sql import SparkSession
from app import main

os.environ['PYSPARK_PYTHON'] = "./pyspark_pex_env.pex"
spark = SparkSession.builder.config(
    "spark.files",  # 'spark.yarn.dist.files' in Yarn.
    "pyspark_pex_env.pex").getOrCreate()
main(spark)
For the interactive pyspark shell, the commands are almost the same:
export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./pyspark_pex_env.pex
pyspark --files pyspark_pex_env.pex
An end-to-end Docker example for deploying a standalone PySpark with SparkSession.builder and PEX can be found here. It uses cluster-pack, a library on top of PEX that automates the intermediate step of having to create and upload the PEX manually.
CHAPTER THREE: API REFERENCE
This page lists an overview of all public PySpark modules, classes, functions and methods.
3.1 Spark SQL
3.1.1 Core Classes
SparkSession(sparkContext[, jsparkSession])
    The entry point to programming Spark with the Dataset and DataFrame API.
DataFrame(jdf, sql_ctx)
    A distributed collection of data grouped into named columns.
Column(jc)
    A column in a DataFrame.
Row(*args, **kwargs)
    A row in DataFrame.
GroupedData(jgd, df)
    A set of methods for aggregations on a DataFrame, created by DataFrame.groupBy().
PandasCogroupedOps(gd1, gd2)
    A logical grouping of two GroupedData, created by GroupedData.cogroup().
DataFrameNaFunctions(df)
    Functionality for working with missing data in DataFrame.
DataFrameStatFunctions(df)
    Functionality for statistic functions with DataFrame.
Window()
    Utility functions for defining window in DataFrames.
pyspark.sql.SparkSession
class pyspark.sql.SparkSession(sparkContext, jsparkSession=None)
    The entry point to programming Spark with the Dataset and DataFrame API.
A SparkSession can be used to create DataFrame, register DataFrame as tables, execute SQL over tables, cache tables, and read parquet files. To create a SparkSession, use the following builder pattern:
builder
    A class attribute having a Builder to construct SparkSession instances.
Examples
>>> spark = SparkSession.builder \
...     .master("local") \
...     .appName("Word Count") \
...     .config("spark.some.config.option", "some-value") \
...     .getOrCreate()
>>> from datetime import datetime
>>> from pyspark.sql import Row
>>> spark = SparkSession(sc)
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1,
...     b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
...     time=datetime(2014, 8, 1, 14, 1, 5))])
>>> df = allTypes.toDF()
>>> df.createOrReplaceTempView("allTypes")
>>> spark.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
...     'from allTypes where b and i > 0').collect()
[Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
>>> df.rdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
[(1, 'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
Methods
createDataFrame(data[, schema, ...])
    Creates a DataFrame from an RDD, a list or a pandas.DataFrame.
getActiveSession()
    Returns the active SparkSession for the current thread, returned by the builder.
newSession()
    Returns a new SparkSession as new session, that has separate SQLConf, registered temporary views and UDFs, but shared SparkContext and table cache.
range(start[, end, step, numPartitions])
    Create a DataFrame with single pyspark.sql.types.LongType column named id, containing elements in a range from start to end (exclusive) with step value step.
sql(sqlQuery)
    Returns a DataFrame representing the result of the given query.
stop()
    Stop the underlying SparkContext.
table(tableName)
    Returns the specified table as a DataFrame.
Attributes
builder
    A class attribute having a Builder to construct SparkSession instances.
catalog
    Interface through which the user may create, drop, alter or query underlying databases, tables, functions, etc.
conf
    Runtime configuration interface for Spark.
read
    Returns a DataFrameReader that can be used to read data in as a DataFrame.
readStream
    Returns a DataStreamReader that can be used to read data streams as a streaming DataFrame.
sparkContext
    Returns the underlying SparkContext.
streams
    Returns a StreamingQueryManager that allows managing all the StreamingQuery instances active on this context.
udf
    Returns a UDFRegistration for UDF registration.
version
    The version of Spark on which this application is running.
pyspark.sql.DataFrame
class pyspark.sql.DataFrame(jdf, sql_ctx)
    A distributed collection of data grouped into named columns.
A DataFrame is equivalent to a relational table in Spark SQL, and can be created using various functions in SparkSession:
people = spark.read.parquet("...")
Once created, it can be manipulated using the various domain-specific-language (DSL) functions defined in: DataFrame, Column.
To select a column from the DataFrame, use the apply method:
ageCol = people.age
A more concrete example:
# To create DataFrame using SparkSession
people = spark.read.parquet("...")
department = spark.read.parquet("...")

people.filter(people.age > 30).join(department, people.deptId == department.id) \
    .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
New in version 1.3.0.
Methods
agg(*exprs)
    Aggregate on the entire DataFrame without groups (shorthand for df.groupBy().agg()).
alias(alias)
    Returns a new DataFrame with an alias set.
approxQuantile(col, probabilities, relativeError)
    Calculates the approximate quantiles of numerical columns of a DataFrame.
cache()
    Persists the DataFrame with the default storage level (MEMORY_AND_DISK).
checkpoint([eager])
    Returns a checkpointed version of this Dataset.
coalesce(numPartitions)
    Returns a new DataFrame that has exactly numPartitions partitions.
colRegex(colName)
    Selects column based on the column name specified as a regex and returns it as Column.
collect()
    Returns all the records as a list of Row.
corr(col1, col2[, method])
    Calculates the correlation of two columns of a DataFrame as a double value.
count()
    Returns the number of rows in this DataFrame.
cov(col1, col2)
    Calculate the sample covariance for the given columns, specified by their names, as a double value.
createGlobalTempView(name)
    Creates a global temporary view with this DataFrame.
createOrReplaceGlobalTempView(name)
    Creates or replaces a global temporary view using the given name.
createOrReplaceTempView(name)
    Creates or replaces a local temporary view with this DataFrame.
createTempView(name)
    Creates a local temporary view with this DataFrame.
crossJoin(other)
    Returns the cartesian product with another DataFrame.
crosstab(col1, col2)
    Computes a pair-wise frequency table of the given columns.
cube(*cols)
    Create a multi-dimensional cube for the current DataFrame using the specified columns, so we can run aggregations on them.
describe(*cols)
    Computes basic statistics for numeric and string columns.
distinct()
    Returns a new DataFrame containing the distinct rows in this DataFrame.
drop(*cols)
    Returns a new DataFrame that drops the specified column.
dropDuplicates([subset])
    Return a new DataFrame with duplicate rows removed, optionally only considering certain columns.
drop_duplicates([subset])
    drop_duplicates() is an alias for dropDuplicates().
dropna([how, thresh, subset])
    Returns a new DataFrame omitting rows with null values.
exceptAll(other)
    Return a new DataFrame containing rows in this DataFrame but not in another DataFrame while preserving duplicates.
explain([extended, mode])
    Prints the (logical and physical) plans to the console for debugging purposes.
fillna(value[, subset])
    Replace null values, alias for na.fill().
filter(condition)
    Filters rows using the given condition.
first()
    Returns the first row as a Row.
foreach(f)
    Applies the f function to all Row of this DataFrame.
foreachPartition(f)
    Applies the f function to each partition of this DataFrame.
freqItems(cols[, support])
    Finding frequent items for columns, possibly with false positives.
groupBy(*cols)
    Groups the DataFrame using the specified columns, so we can run aggregation on them.
groupby(*cols)
    groupby() is an alias for groupBy().
head([n])
    Returns the first n rows.
hint(name, *parameters)
    Specifies some hint on the current DataFrame.
inputFiles()
    Returns a best-effort snapshot of the files that compose this DataFrame.
intersect(other)
    Return a new DataFrame containing rows only in both this DataFrame and another DataFrame.
intersectAll(other)
    Return a new DataFrame containing rows in both this DataFrame and another DataFrame while preserving duplicates.
isLocal()
    Returns True if the collect() and take() methods can be run locally (without any Spark executors).
join(other[, on, how])
    Joins with another DataFrame, using the given join expression.
limit(num)
    Limits the result count to the number specified.
localCheckpoint([eager])
    Returns a locally checkpointed version of this Dataset.
mapInPandas(func, schema)
    Maps an iterator of batches in the current DataFrame using a Python native function that takes and outputs a pandas DataFrame, and returns the result as a DataFrame.
orderBy(*cols, **kwargs)
    Returns a new DataFrame sorted by the specified column(s).
persist([storageLevel])
    Sets the storage level to persist the contents of the DataFrame across operations after the first time it is computed.
printSchema()
    Prints out the schema in the tree format.
randomSplit(weights[, seed])
    Randomly splits this DataFrame with the provided weights.
registerTempTable(name)
    Registers this DataFrame as a temporary table using the given name.
repartition(numPartitions, *cols)
    Returns a new DataFrame partitioned by the given partitioning expressions.
repartitionByRange(numPartitions, *cols)
    Returns a new DataFrame partitioned by the given partitioning expressions.
replace(to_replace[, value, subset])
    Returns a new DataFrame replacing a value with another value.
rollup(*cols)
    Create a multi-dimensional rollup for the current DataFrame using the specified columns, so we can run aggregation on them.
sameSemantics(other)
    Returns True when the logical query plans inside both DataFrames are equal and therefore return the same results.
sample([withReplacement, fraction, seed])
    Returns a sampled subset of this DataFrame.
sampleBy(col, fractions[, seed])
    Returns a stratified sample without replacement based on the fraction given on each stratum.
select(*cols)
    Projects a set of expressions and returns a new DataFrame.
selectExpr(*expr)
    Projects a set of SQL expressions and returns a new DataFrame.
semanticHash()
    Returns a hash code of the logical query plan against this DataFrame.
show([n, truncate, vertical])
    Prints the first n rows to the console.
sort(*cols, **kwargs)
    Returns a new DataFrame sorted by the specified column(s).
sortWithinPartitions(*cols, **kwargs)
    Returns a new DataFrame with each partition sorted by the specified column(s).
subtract(other)
    Return a new DataFrame containing rows in this DataFrame but not in another DataFrame.
summary(*statistics)
    Computes specified statistics for numeric and string columns.
tail(num)
    Returns the last num rows as a list of Row.
take(num)
    Returns the first num rows as a list of Row.
toDF(*cols)
    Returns a new DataFrame with new specified column names.
toJSON([use_unicode])
    Converts a DataFrame into an RDD of string.
toLocalIterator([prefetchPartitions])
    Returns an iterator that contains all of the rows in this DataFrame.
toPandas()
    Returns the contents of this DataFrame as a pandas.DataFrame.
transform(func)
    Returns a new DataFrame.
union(other)
    Return a new DataFrame containing the union of rows in this and another DataFrame.
unionAll(other)
    Return a new DataFrame containing the union of rows in this and another DataFrame.
unionByName(other[, allowMissingColumns])
    Returns a new DataFrame containing the union of rows in this and another DataFrame.
unpersist([blocking])
    Marks the DataFrame as non-persistent, and removes all blocks for it from memory and disk.
where(condition)
    where() is an alias for filter().
withColumn(colName, col)
    Returns a new DataFrame by adding a column or replacing the existing column that has the same name.
withColumnRenamed(existing, new)
    Returns a new DataFrame by renaming an existing column.
withWatermark(eventTime, delayThreshold)
    Defines an event time watermark for this DataFrame.
writeTo(table)
    Create a write configuration builder for v2 sources.
Attributes
columns
    Returns all column names as a list.
dtypes
    Returns all column names and their data types as a list.
isStreaming
    Returns True if this Dataset contains one or more sources that continuously return data as it arrives.
na
    Returns a DataFrameNaFunctions for handling missing values.
rdd
    Returns the content as a pyspark.RDD of Row.
schema
    Returns the schema of this DataFrame as a pyspark.sql.types.StructType.
stat
    Returns a DataFrameStatFunctions for statistic functions.
storageLevel
    Get the DataFrame's current storage level.
write
    Interface for saving the content of the non-streaming DataFrame out into external storage.
writeStream
    Interface for saving the content of the streaming DataFrame out into external storage.
pyspark.sql.Column
class pyspark.sql.Column(jc)
A column in a DataFrame.
Column instances can be created by:
# 1. Select a column out of a DataFrame
df.colName
df["colName"]
# 2. Create from an expression
df.colName + 1
1 / df.colName
New in version 1.3.0.
Methods
alias(*alias, **kwargs)
    Returns this column aliased with a new name or names (in the case of expressions that return more than one column, such as explode).
asc()
    Returns a sort expression based on ascending order of the column.
asc_nulls_first()
    Returns a sort expression based on ascending order of the column, and null values return before non-null values.
asc_nulls_last()
    Returns a sort expression based on ascending order of the column, and null values appear after non-null values.
astype(dataType)
    astype() is an alias for cast().
between(lowerBound, upperBound)
    A boolean expression that is evaluated to true if the value of this expression is between the given columns.
bitwiseAND(other)
    Compute bitwise AND of this expression with another expression.
bitwiseOR(other)
    Compute bitwise OR of this expression with another expression.
bitwiseXOR(other)
    Compute bitwise XOR of this expression with another expression.
cast(dataType)
    Convert the column into type dataType.
contains(other)
    Contains the other element.
desc()
    Returns a sort expression based on the descending order of the column.
desc_nulls_first()
    Returns a sort expression based on the descending order of the column, and null values appear before non-null values.
desc_nulls_last()
    Returns a sort expression based on the descending order of the column, and null values appear after non-null values.
dropFields(*fieldNames)
    An expression that drops fields in StructType by name.
endswith(other)
    String ends with.
eqNullSafe(other)
    Equality test that is safe for null values.
getField(name)
    An expression that gets a field by name in a StructField.
getItem(key)
    An expression that gets an item at position ordinal out of a list, or gets an item by key out of a dict.
isNotNull()
    True if the current expression is NOT null.
isNull()
    True if the current expression is null.
isin(*cols)
    A boolean expression that is evaluated to true if the value of this expression is contained by the evaluated values of the arguments.
like(other)
    SQL like expression.
name(*alias, **kwargs)
    name() is an alias for alias().
otherwise(value)
    Evaluates a list of conditions and returns one of multiple possible result expressions.
over(window)
    Define a windowing column.
rlike(other)
    SQL RLIKE expression (LIKE with Regex).
startswith(other)
    String starts with.
substr(startPos, length)
    Return a Column which is a substring of the column.
when(condition, value)
    Evaluates a list of conditions and returns one of multiple possible result expressions.
withField(fieldName, col)
    An expression that adds/replaces a field in StructType by name.
pyspark.sql.Row
class pyspark.sql.Row(*args, **kwargs)
A row in DataFrame. The fields in it can be accessed:
• like attributes (row.key)
• like dictionary values (row[key])
key in row will search through row keys.
Row can be used to create a row object by using named arguments. It is not allowed to omit a named argument to represent that the value is None or missing. This should be explicitly set to None in this case.
Changed in version 3.0.0: Rows created from named arguments no longer have field names sorted alphabetically and will be ordered in the position as entered.
Examples
>>> row = Row(name="Alice", age=11)
>>> row
Row(name='Alice', age=11)
>>> row['name'], row['age']
('Alice', 11)
>>> row.name, row.age
('Alice', 11)
>>> 'name' in row
True
>>> 'wrong_key' in row
False
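As a minimal sketch of the explicit-None rule described above (the values are hypothetical), a missing value must be passed explicitly rather than omitted:

>>> row = Row(name="Alice", age=None)
>>> row.age is None
True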
Row can also be used to create another Row-like class, which can then be used to create Row objects, such as
>>> Person = Row("name", "age")
>>> Person
<Row('name', 'age')>
>>> 'name' in Person
True
>>> 'wrong_key' in Person
False
>>> Person("Alice", 11)
Row(name='Alice', age=11)
This form can also be used to create rows as tuple values, i.e. with unnamed fields.
>>> row1 = Row("Alice", 11)
>>> row2 = Row(name="Alice", age=11)
>>> row1 == row2
True
Methods
asDict([recursive])
    Return as a dict.
count(value, /)
    Return number of occurrences of value.
index(value[, start, stop])
    Return first index of value.
pyspark.sql.GroupedData
class pyspark.sql.GroupedData(jgd, df)
A set of methods for aggregations on a DataFrame, created by DataFrame.groupBy().
New in version 1.3.
Methods
agg(*exprs)
    Compute aggregates and returns the result as a DataFrame.
apply(udf)
    It is an alias of pyspark.sql.GroupedData.applyInPandas(); however, it takes a pyspark.sql.functions.pandas_udf() whereas pyspark.sql.GroupedData.applyInPandas() takes a Python native function.
applyInPandas(func, schema)
    Maps each group of the current DataFrame using a pandas udf and returns the result as a DataFrame.
avg(*cols)
    Computes average values for each numeric column for each group.
cogroup(other)
    Cogroups this group with another group so that we can run cogrouped operations.
count()
    Counts the number of records for each group.
max(*cols)
    Computes the max value for each numeric column for each group.
mean(*cols)
    Computes average values for each numeric column for each group.
min(*cols)
    Computes the min value for each numeric column for each group.
pivot(pivot_col[, values])
    Pivots a column of the current DataFrame and performs the specified aggregation.
sum(*cols)
    Computes the sum for each numeric column for each group.
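A short illustrative sketch (hypothetical data) combining groupBy() with agg():

>>> df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["key", "value"])
>>> sorted(df.groupBy("key").agg({"value": "sum"}).collect())
[Row(key='a', sum(value)=3), Row(key='b', sum(value)=3)]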
pyspark.sql.PandasCogroupedOps
class pyspark.sql.PandasCogroupedOps(gd1, gd2)
A logical grouping of two GroupedData, created by GroupedData.cogroup().
New in version 3.0.0.
Notes
This API is experimental.
Methods
applyInPandas(func, schema)
    Applies a function to each cogroup using pandas and returns the result as a DataFrame.
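A hedged sketch of cogrouped applyInPandas (the data is hypothetical, and pandas plus PyArrow must be installed): the function receives one pandas DataFrame per side of the cogroup and returns a pandas DataFrame matching the declared schema.

>>> import pandas as pd
>>> df1 = spark.createDataFrame([(1, 1.0), (2, 2.0)], ("id", "v1"))
>>> df2 = spark.createDataFrame([(1, "x"), (2, "y")], ("id", "v2"))
>>> def merge(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
...     # join the two cogrouped pandas frames on their common key
...     return pd.merge(left, right, on="id")
>>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
...     merge, schema="id long, v1 double, v2 string").sort("id").collect()
[Row(id=1, v1=1.0, v2='x'), Row(id=2, v1=2.0, v2='y')]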
pyspark.sql.DataFrameNaFunctions
class pyspark.sql.DataFrameNaFunctions(df)
Functionality for working with missing data in DataFrame.
New in version 1.4.
Methods
drop([how, thresh, subset])
    Returns a new DataFrame omitting rows with null values.
fill(value[, subset])
    Replace null values, alias for na.fill().
replace(to_replace[, value, subset])
    Returns a new DataFrame replacing a value with another value.
pyspark.sql.DataFrameStatFunctions
class pyspark.sql.DataFrameStatFunctions(df)
Functionality for statistic functions with DataFrame.
New in version 1.4.
Methods
approxQuantile(col, probabilities, relativeError)
    Calculates the approximate quantiles of numerical columns of a DataFrame.
corr(col1, col2[, method])
    Calculates the correlation of two columns of a DataFrame as a double value.
cov(col1, col2)
    Calculate the sample covariance for the given columns, specified by their names, as a double value.
crosstab(col1, col2)
    Computes a pair-wise frequency table of the given columns.
freqItems(cols[, support])
    Finding frequent items for columns, possibly with false positives.
sampleBy(col, fractions[, seed])
    Returns a stratified sample without replacement based on the fraction given on each stratum.
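A quick sketch (hypothetical data) using the stat accessor; the two columns here are perfectly linearly related, so corr() returns 1.0:

>>> df = spark.createDataFrame([(1, 10), (2, 20), (3, 30)], ["a", "b"])
>>> df.stat.corr("a", "b")
1.0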
pyspark.sql.Window
class pyspark.sql.Window
Utility functions for defining windows in DataFrames.
New in version 1.4.
Notes
When ordering is not defined, an unbounded window frame (rowFrame, unboundedPreceding, unboundedFollowing) is used by default. When ordering is defined, a growing window frame (rangeFrame, unboundedPreceding, currentRow) is used by default.
Examples
>>> # ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
>>> window = Window.orderBy("date").rowsBetween(Window.unboundedPreceding, Window.currentRow)
>>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING
>>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3)
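A hedged end-to-end sketch (hypothetical data): with an ordering defined, the default growing frame noted above turns an aggregate into a running total when applied with Column.over():

>>> from pyspark.sql import functions as F
>>> df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["key", "value"])
>>> w = Window.partitionBy("key").orderBy("value")
>>> df.withColumn("running_sum", F.sum("value").over(w)).sort("key", "value").collect()
[Row(key='a', value=1, running_sum=1), Row(key='a', value=2, running_sum=3), Row(key='b', value=3, running_sum=3)]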
Methods
orderBy(*cols)
    Creates a WindowSpec with the ordering defined.
partitionBy(*cols)
    Creates a WindowSpec with the partitioning defined.
rangeBetween(start, end)
    Creates a WindowSpec with the frame boundaries defined, from start (inclusive) to end (inclusive).
rowsBetween(start, end)
    Creates a WindowSpec with the frame boundaries defined, from start (inclusive) to end (inclusive).
Attributes
currentRow
unboundedFollowing
unboundedPreceding
3.1.2 Spark Session APIs
The entry point to programming Spark with the Dataset and DataFrame API. To create a Spark session, you should use the SparkSession.builder attribute. See also SparkSession.
SparkSession.builder.appName(name)
    Sets a name for the application, which will be shown in the Spark web UI.
SparkSession.builder.config([key, value, conf])
    Sets a config option.
SparkSession.builder.enableHiveSupport()
    Enables Hive support, including connectivity to a persistent Hive metastore, support for Hive SerDes, and Hive user-defined functions.
SparkSession.builder.getOrCreate()
    Gets an existing SparkSession or, if there is no existing one, creates a new one based on the options set in this builder.
SparkSession.builder.master(master)
    Sets the Spark master URL to connect to, such as "local" to run locally, "local[4]" to run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone cluster.
SparkSession.catalog
    Interface through which the user may create, drop, alter or query underlying databases, tables, functions, etc.
SparkSession.conf
    Runtime configuration interface for Spark.
SparkSession.createDataFrame(data[, schema, ...])
    Creates a DataFrame from an RDD, a list or a pandas.DataFrame.
SparkSession.getActiveSession()
    Returns the active SparkSession for the current thread, returned by the builder.
SparkSession.newSession()
    Returns a new SparkSession as new session, that has separate SQLConf, registered temporary views and UDFs, but shared SparkContext and table cache.
SparkSession.range(start[, end, step, ...])
    Create a DataFrame with a single pyspark.sql.types.LongType column named id, containing elements in a range from start to end (exclusive) with step value step.
SparkSession.read
    Returns a DataFrameReader that can be used to read data in as a DataFrame.
SparkSession.readStream
    Returns a DataStreamReader that can be used to read data streams as a streaming DataFrame.
SparkSession.sparkContext
    Returns the underlying SparkContext.
SparkSession.sql(sqlQuery)
    Returns a DataFrame representing the result of the given query.
SparkSession.stop()
    Stop the underlying SparkContext.
SparkSession.streams
    Returns a StreamingQueryManager that allows managing all the StreamingQuery instances active on this context.
SparkSession.table(tableName)
    Returns the specified table as a DataFrame.
SparkSession.udf
    Returns a UDFRegistration for UDF registration.
SparkSession.version
    The version of Spark on which this application is running.
pyspark.sql.SparkSession.builder.appName
SparkSession.builder.appName(name)
Sets a name for the application, which will be shown in the Spark web UI.
If no application name is set, a randomly generated name will be used.
New in version 2.0.0.
Parameters
name [str] an application name
pyspark.sql.SparkSession.builder.config
SparkSession.builder.config(key=None, value=None, conf=None)
Sets a config option. Options set using this method are automatically propagated to both SparkConf and SparkSession's own configuration.
New in version 2.0.0.
Parameters
key [str, optional] a key name string for configuration property
value [str, optional] a value for configuration property
conf [SparkConf, optional] an instance of SparkConf
Examples
For an existing SparkConf, use conf parameter.
>>> from pyspark.conf import SparkConf
>>> SparkSession.builder.config(conf=SparkConf())
<pyspark.sql.session...
For a (key, value) pair, you can omit parameter names.
>>> SparkSession.builder.config("spark.some.config.option", "some-value")
<pyspark.sql.session...
pyspark.sql.SparkSession.builder.enableHiveSupport
SparkSession.builder.enableHiveSupport()
Enables Hive support, including connectivity to a persistent Hive metastore, support for Hive SerDes, and Hive user-defined functions.
New in version 2.0.
pyspark.sql.SparkSession.builder.getOrCreate
SparkSession.builder.getOrCreate()
Gets an existing SparkSession or, if there is no existing one, creates a new one based on the options set in this builder.
New in version 2.0.0.
Examples
This method first checks whether there is a valid global default SparkSession and, if so, returns that one. If no valid global default SparkSession exists, the method creates a new SparkSession and assigns the newly created SparkSession as the global default.
>>> s1 = SparkSession.builder.config("k1", "v1").getOrCreate()
>>> s1.conf.get("k1") == "v1"
True
In case an existing SparkSession is returned, the config options specified in this builder will be applied to theexisting SparkSession.
>>> s2 = SparkSession.builder.config("k2", "v2").getOrCreate()
>>> s1.conf.get("k1") == s2.conf.get("k1")
True
>>> s1.conf.get("k2") == s2.conf.get("k2")
True
pyspark.sql.SparkSession.builder.master
SparkSession.builder.master(master)
Sets the Spark master URL to connect to, such as "local" to run locally, "local[4]" to run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone cluster.
New in version 2.0.0.
Parameters
master [str] a URL for the Spark master
pyspark.sql.SparkSession.catalog
property SparkSession.catalog
Interface through which the user may create, drop, alter or query underlying databases, tables, functions, etc.
New in version 2.0.0.
Returns
Catalog
pyspark.sql.SparkSession.conf
property SparkSession.conf
Runtime configuration interface for Spark.
This is the interface through which the user can get and set all Spark and Hadoop configurations that are relevant to Spark SQL. When getting the value of a config, this defaults to the value set in the underlying SparkContext, if any.
New in version 2.0.
pyspark.sql.SparkSession.createDataFrame
SparkSession.createDataFrame(data, schema=None, samplingRatio=None, verifySchema=True)
Creates a DataFrame from an RDD, a list or a pandas.DataFrame.
When schema is a list of column names, the type of each column will be inferred from data.
When schema is None, it will try to infer the schema (column names and types) from data, which should be an RDD of either Row, namedtuple, or dict.
When schema is pyspark.sql.types.DataType or a datatype string, it must match the real data, or an exception will be thrown at runtime. If the given schema is not pyspark.sql.types.StructType, it will be wrapped into a pyspark.sql.types.StructType as its only field, and the field name will be "value". Each record will also be wrapped into a tuple, which can be converted to row later.
If schema inference is needed, samplingRatio is used to determine the ratio of rows used for schema inference. The first row will be used if samplingRatio is None.
New in version 2.0.0.
Changed in version 2.1.0: Added verifySchema.
Parameters
data [RDD or iterable] an RDD of any kind of SQL data representation (e.g. Row, tuple, int, boolean, etc.), or list, or pandas.DataFrame.
schema [pyspark.sql.types.DataType, str or list, optional] a pyspark.sql.types.DataType or a datatype string or a list of column names, default is None. The data type string format equals to pyspark.sql.types.DataType.simpleString, except that the top level struct type can omit the struct<> and atomic types use typeName() as their format, e.g. use byte instead of tinyint for pyspark.sql.types.ByteType. We can also use int as a short name for pyspark.sql.types.IntegerType.
samplingRatio [float, optional] the sample ratio of rows used for inferring
verifySchema [bool, optional] verify data types of every row against schema. Enabled by default.
Returns
DataFrame
Notes
Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental.
Examples
>>> l = [('Alice', 1)]
>>> spark.createDataFrame(l).collect()
[Row(_1='Alice', _2=1)]
>>> spark.createDataFrame(l, ['name', 'age']).collect()
[Row(name='Alice', age=1)]
>>> d = [{'name': 'Alice', 'age': 1}]
>>> spark.createDataFrame(d).collect()
[Row(age=1, name='Alice')]
>>> rdd = sc.parallelize(l)
>>> spark.createDataFrame(rdd).collect()
[Row(_1='Alice', _2=1)]
>>> df = spark.createDataFrame(rdd, ['name', 'age'])
>>> df.collect()
[Row(name='Alice', age=1)]
>>> from pyspark.sql import Row
>>> Person = Row('name', 'age')
>>> person = rdd.map(lambda r: Person(*r))
>>> df2 = spark.createDataFrame(person)
>>> df2.collect()
[Row(name='Alice', age=1)]
>>> from pyspark.sql.types import *
>>> schema = StructType([
...     StructField("name", StringType(), True),
...     StructField("age", IntegerType(), True)])
>>> df3 = spark.createDataFrame(rdd, schema)
>>> df3.collect()
[Row(name='Alice', age=1)]
>>> spark.createDataFrame(df.toPandas()).collect()
[Row(name='Alice', age=1)]
>>> spark.createDataFrame(pandas.DataFrame([[1, 2]])).collect()
[Row(0=1, 1=2)]
>>> spark.createDataFrame(rdd, "a: string, b: int").collect()
[Row(a='Alice', b=1)]
>>> rdd = rdd.map(lambda row: row[1])
>>> spark.createDataFrame(rdd, "int").collect()
[Row(value=1)]
>>> spark.createDataFrame(rdd, "boolean").collect()
Traceback (most recent call last):
    ...
Py4JJavaError: ...
pyspark.sql.SparkSession.getActiveSession
classmethod SparkSession.getActiveSession()
Returns the active SparkSession for the current thread, returned by the builder.
New in version 3.0.0.
Returns
SparkSession Spark session if an active session exists for the current thread
Examples
>>> s = SparkSession.getActiveSession()
>>> l = [('Alice', 1)]
>>> rdd = s.sparkContext.parallelize(l)
>>> df = s.createDataFrame(rdd, ['name', 'age'])
>>> df.select("age").collect()
[Row(age=1)]
pyspark.sql.SparkSession.newSession
SparkSession.newSession()
Returns a new SparkSession as a new session, which has separate SQLConf, registered temporary views and UDFs, but a shared SparkContext and table cache.
New in version 2.0.
pyspark.sql.SparkSession.range
SparkSession.range(start, end=None, step=1, numPartitions=None)
Create a DataFrame with a single pyspark.sql.types.LongType column named id, containing elements in a range from start to end (exclusive) with step value step.
New in version 2.0.0.
Parameters
start [int] the start value
end [int, optional] the end value (exclusive)
step [int, optional] the incremental step (default: 1)
numPartitions [int, optional] the number of partitions of the DataFrame
Returns
DataFrame
Examples
>>> spark.range(1, 7, 2).collect()
[Row(id=1), Row(id=3), Row(id=5)]
If only one argument is specified, it will be used as the end value.
>>> spark.range(3).collect()
[Row(id=0), Row(id=1), Row(id=2)]
pyspark.sql.SparkSession.read
property SparkSession.read
Returns a DataFrameReader that can be used to read data in as a DataFrame.
New in version 2.0.0.
Returns
DataFrameReader
pyspark.sql.SparkSession.readStream
property SparkSession.readStream
Returns a DataStreamReader that can be used to read data streams as a streaming DataFrame.
New in version 2.0.0.
Returns
DataStreamReader
Notes
This API is evolving.
pyspark.sql.SparkSession.sparkContext
property SparkSession.sparkContext
Returns the underlying SparkContext.
New in version 2.0.
pyspark.sql.SparkSession.sql
SparkSession.sql(sqlQuery)
Returns a DataFrame representing the result of the given query.
New in version 2.0.0.
Returns
DataFrame
Examples
>>> df.createOrReplaceTempView("table1")
>>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
[Row(f1=1, f2='row1'), Row(f1=2, f2='row2'), Row(f1=3, f2='row3')]
pyspark.sql.SparkSession.stop
SparkSession.stop()
Stop the underlying SparkContext.
New in version 2.0.
pyspark.sql.SparkSession.streams
property SparkSession.streams
Returns a StreamingQueryManager that allows managing all the StreamingQuery instances active on this context.
New in version 2.0.0.
Returns
StreamingQueryManager
Notes
This API is evolving.
pyspark.sql.SparkSession.table
SparkSession.table(tableName)
Returns the specified table as a DataFrame.
New in version 2.0.0.
Returns
DataFrame
Examples
>>> df.createOrReplaceTempView("table1")
>>> df2 = spark.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
True
pyspark.sql.SparkSession.udf
property SparkSession.udf
Returns a UDFRegistration for UDF registration.
New in version 2.0.0.
Returns
UDFRegistration
pyspark.sql.SparkSession.version
property SparkSession.version
The version of Spark on which this application is running.
New in version 2.0.
3.1.3 Input and Output
DataFrameReader.csv(path[, schema, sep, ...])
    Loads a CSV file and returns the result as a DataFrame.
DataFrameReader.format(source)
    Specifies the input data source format.
DataFrameReader.jdbc(url, table[, column, ...])
    Construct a DataFrame representing the database table named table accessible via JDBC URL url and connection properties.
DataFrameReader.json(path[, schema, ...])
    Loads JSON files and returns the results as a DataFrame.
DataFrameReader.load([path, format, schema])
    Loads data from a data source and returns it as a DataFrame.
DataFrameReader.option(key, value)
    Adds an input option for the underlying data source.
DataFrameReader.options(**options)
    Adds input options for the underlying data source.
DataFrameReader.orc(path[, mergeSchema, ...])
    Loads ORC files, returning the result as a DataFrame.
DataFrameReader.parquet(*paths, **options)
    Loads Parquet files, returning the result as a DataFrame.
DataFrameReader.schema(schema)
    Specifies the input schema.
DataFrameReader.table(tableName)
    Returns the specified table as a DataFrame.
DataFrameWriter.bucketBy(numBuckets, col, *cols)
    Buckets the output by the given columns. If specified, the output is laid out on the file system similar to Hive's bucketing scheme.
DataFrameWriter.csv(path[, mode, ...])
    Saves the content of the DataFrame in CSV format at the specified path.
DataFrameWriter.format(source)
    Specifies the underlying output data source.
DataFrameWriter.insertInto(tableName[, ...])
    Inserts the content of the DataFrame to the specified table.
DataFrameWriter.jdbc(url, table[, mode, ...])
    Saves the content of the DataFrame to an external database table via JDBC.
DataFrameWriter.json(path[, mode, ...])
    Saves the content of the DataFrame in JSON format (JSON Lines text format or newline-delimited JSON) at the specified path.
DataFrameWriter.mode(saveMode)
    Specifies the behavior when data or table already exists.
DataFrameWriter.option(key, value)
    Adds an output option for the underlying data source.
DataFrameWriter.options(**options)
    Adds output options for the underlying data source.
DataFrameWriter.orc(path[, mode, ...])
    Saves the content of the DataFrame in ORC format at the specified path.
DataFrameWriter.parquet(path[, mode, ...])
    Saves the content of the DataFrame in Parquet format at the specified path.
DataFrameWriter.partitionBy(*cols)
    Partitions the output by the given columns on the file system.
DataFrameWriter.save([path, format, mode, ...])
    Saves the contents of the DataFrame to a data source.
DataFrameWriter.saveAsTable(name[, format, ...])
    Saves the content of the DataFrame as the specified table.
DataFrameWriter.sortBy(col, *cols)
    Sorts the output in each bucket by the given columns on the file system.
DataFrameWriter.text(path[, compression, ...])
    Saves the content of the DataFrame in a text file at the specified path.
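As a hedged round-trip sketch tying the reader and writer together (the /tmp path is hypothetical):

>>> df = spark.range(3)
>>> df.write.mode("overwrite").parquet("/tmp/example_parquet")
>>> spark.read.parquet("/tmp/example_parquet").count()
3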
pyspark.sql.DataFrameReader.csv
DataFrameReader.csv(path, schema=None, sep=None, encoding=None, quote=None, escape=None, comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None, pathGlobFilter=None, recursiveFileLookup=None, modifiedBefore=None, modifiedAfter=None, unescapedQuoteHandling=None)
Loads a CSV file and returns the result as a DataFrame.
This function will go through the input once to determine the input schema if inferSchema is enabled. To avoid going through the entire data once, disable the inferSchema option or specify the schema explicitly using schema.
New in version 2.0.0.
Parameters
path [str or list] string, or list of strings, for input path(s), or RDD of Strings storing CSV rows.
schema [pyspark.sql.types.StructType or str, optional] an optional pyspark.sql.types.StructType for the input schema or a DDL-formatted string (For example col0 INT, col1 DOUBLE).
sep [str, optional] sets a separator (one or more characters) for each field and value. If None is set, it uses the default value, ,.
encoding [str, optional] decodes the CSV files by the given encoding type. If None is set, it uses the default value, UTF-8.
quote [str, optional] sets a single character used for escaping quoted values where the separator can be part of the value. If None is set, it uses the default value, ". If you would like to turn off quotations, you need to set an empty string.
escape [str, optional] sets a single character used for escaping quotes inside an already quoted value. If None is set, it uses the default value, \.
comment [str, optional] sets a single character used for skipping lines beginning with this character. By default (None), it is disabled.
header [str or bool, optional] uses the first line as names of columns. If None is set, it uses the default value, false.
Note: if the given path is an RDD of Strings, this header option will remove all lines identical to the header, if such lines exist.
inferSchema [str or bool, optional] infers the input schema automatically from data. It requiresone extra pass over the data. If None is set, it uses the default value, false.
enforceSchema [str or bool, optional] If it is set to true, the specified or inferred schema will be forcibly applied to datasource files, and headers in CSV files will be ignored. If the option is set to false, the schema will be validated against all headers in CSV files or the first header in RDD if the header option is set to true. Field names in the schema and column names in CSV headers are checked by their positions taking into account spark.sql.caseSensitive. If None is set, true is used by default. Though the default value is true, it is recommended to disable the enforceSchema option to avoid incorrect results.
ignoreLeadingWhiteSpace [str or bool, optional] A flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, false.
ignoreTrailingWhiteSpace [str or bool, optional] A flag indicating whether or not trailing whitespaces from values being read should be skipped. If None is set, it uses the default value, false.
nullValue [str, optional] sets the string representation of a null value. If None is set, it uses the default value, empty string. Since 2.0.1, this nullValue param applies to all supported types including the string type.
nanValue [str, optional] sets the string representation of a non-number value. If None is set, it uses the default value, NaN.
positiveInf [str, optional] sets the string representation of a positive infinity value. If None is set, it uses the default value, Inf.
negativeInf [str, optional] sets the string representation of a negative infinity value. If None is set, it uses the default value, Inf.
dateFormat [str, optional] sets the string that indicates a date format. Custom date formats follow the formats at datetime pattern. This applies to date type. If None is set, it uses the default value, yyyy-MM-dd.
timestampFormat [str, optional] sets the string that indicates a timestamp format. Custom date formats follow the formats at datetime pattern. This applies to timestamp type. If None is set, it uses the default value, yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX].
maxColumns [str or int, optional] defines a hard limit of how many columns a record can have. If None is set, it uses the default value, 20480.
maxCharsPerColumn [str or int, optional] defines the maximum number of characters allowed for any given value being read. If None is set, it uses the default value, -1 meaning unlimited length.
maxMalformedLogPerPartition [str or int, optional] this parameter is no longer used since Spark 2.2.0. If specified, it is ignored.
mode [str, optional] allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, PERMISSIVE. Note that Spark tries to parse only required columns in CSV under column pruning. Therefore, corrupt records can be different based on the required set of fields. This behavior can be controlled by spark.sql.csv.parser.columnPruning.enabled (enabled by default).
• PERMISSIVE: when it meets a corrupted record, puts the malformed string into a field configured by columnNameOfCorruptRecord, and sets malformed fields to null. To keep corrupt records, a user can set a string type field named columnNameOfCorruptRecord in a user-defined schema. If a schema does not have the field, it drops corrupt records during parsing. A record with fewer or more tokens than the schema is not a corrupted record to CSV. When it meets a record having fewer tokens than the length of the schema, it sets null to extra fields. When the record has more tokens than the length of the schema, it drops extra tokens.
• DROPMALFORMED: ignores the whole corrupted records.
• FAILFAST: throws an exception when it meets corrupted records.
columnNameOfCorruptRecord [str, optional] allows renaming the new field having the malformed string created by PERMISSIVE mode. This overrides spark.sql.columnNameOfCorruptRecord. If None is set, it uses the value specified in spark.sql.columnNameOfCorruptRecord.
multiLine [str or bool, optional] parse records, which may span multiple lines. If None is set, it uses the default value, false.
charToEscapeQuoteEscaping [str, optional] sets a single character used for escaping the escape for the quote character. If None is set, the default value is the escape character when escape and quote characters are different, \0 otherwise.
samplingRatio [str or float, optional] defines fraction of rows used for schema inferring. If None is set, it uses the default value, 1.0.
emptyValue [str, optional] sets the string representation of an empty value. If None is set, it uses the default value, empty string.
locale [str, optional] sets a locale as language tag in IETF BCP 47 format. If None is set, it uses the default value, en-US. For instance, locale is used while parsing dates and timestamps.
lineSep [str, optional] defines the line separator that should be used for parsing. If None is set, it covers all \r, \r\n and \n. Maximum length is 1 character.
pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.
recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.
modifiedBefore (batch only) [str, optional] an optional timestamp to only include files with modification times occurring before the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
modifiedAfter (batch only) [str, optional] an optional timestamp to only include files with modification times occurring after the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
unescapedQuoteHandling [str, optional] defines how the CsvParser will handle values with unescaped quotes. If None is set, it uses the default value, STOP_AT_DELIMITER.
• STOP_AT_CLOSING_QUOTE: If unescaped quotes are found in the input, accumulate the quote character and proceed parsing the value as a quoted value, until a closing quote is found.
• BACK_TO_DELIMITER: If unescaped quotes are found in the input, consider the value as an unquoted value. This will make the parser accumulate all characters of the current parsed value until the delimiter is found. If no delimiter is found in the value, the parser will continue accumulating characters from the input until a delimiter or line ending is found.
• STOP_AT_DELIMITER: If unescaped quotes are found in the input, consider the value as an unquoted value. This will make the parser accumulate all characters until the delimiter or a line ending is found in the input.
• SKIP_VALUE: If unescaped quotes are found in the input, the content parsed for the given value will be skipped and the value set in nullValue will be produced instead.
• RAISE_ERROR: If unescaped quotes are found in the input, a TextParsingException will be thrown.
Examples
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
[('_c0', 'string'), ('_c1', 'string')]
>>> rdd = sc.textFile('python/test_support/sql/ages.csv')
>>> df2 = spark.read.csv(rdd)
>>> df2.dtypes
[('_c0', 'string'), ('_c1', 'string')]
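A hedged sketch of the PERMISSIVE mode described above (the corrupt-record column name and schema are illustrative): adding a string field to the schema and pointing columnNameOfCorruptRecord at it keeps malformed rows instead of dropping them.

>>> from pyspark.sql.types import StructType, StructField, StringType
>>> schema = StructType([
...     StructField("age", StringType(), True),
...     StructField("name", StringType(), True),
...     StructField("_corrupt", StringType(), True)])
>>> df3 = spark.read.csv('python/test_support/sql/ages.csv', schema=schema,
...     mode="PERMISSIVE", columnNameOfCorruptRecord="_corrupt")
>>> df3.columns
['age', 'name', '_corrupt']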
pyspark.sql.DataFrameReader.format
DataFrameReader.format(source)
Specifies the input data source format.
New in version 1.4.0.
Parameters
source [str] string, name of the data source, e.g. ‘json’, ‘parquet’.
Examples
>>> df = spark.read.format('json').load('python/test_support/sql/people.json')
>>> df.dtypes
[('age', 'bigint'), ('name', 'string')]
pyspark.sql.DataFrameReader.jdbc
DataFrameReader.jdbc(url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, predicates=None, properties=None)
Construct a DataFrame representing the database table named table accessible via JDBC URL url and connection properties.
Partitions of the table will be retrieved in parallel if either column or predicates is specified. lowerBound, upperBound and numPartitions are needed when column is specified.
If both column and predicates are specified, column will be used.
New in version 1.4.0.
Parameters
url [str] a JDBC URL of the form jdbc:subprotocol:subname
table [str] the name of the table
column [str, optional] the name of a column of numeric, date, or timestamp type that will be used for partitioning; if this parameter is specified, then numPartitions, lowerBound (inclusive), and upperBound (exclusive) will form partition strides for generated WHERE clause expressions used to split the column column evenly
lowerBound [str or int, optional] the minimum value of column used to decide partition stride
upperBound [str or int, optional] the maximum value of column used to decide partition stride
numPartitions [int, optional] the number of partitions
predicates [list, optional] a list of expressions suitable for inclusion in WHERE clauses; eachone defines one partition of the DataFrame
properties [dict, optional] a dictionary of JDBC database connection arguments. Normally at least properties "user" and "password" with their corresponding values. For example { 'user' : 'SYSTEM', 'password' : 'mypassword' }
Returns
DataFrame
Notes
Don’t create too many partitions in parallel on a large cluster; otherwise Spark might crash your externaldatabase systems.
pyspark.sql.DataFrameReader.json
DataFrameReader.json(path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, dropFieldIfAllNull=None, encoding=None, locale=None, pathGlobFilter=None, recursiveFileLookup=None, allowNonNumericNumbers=None, modifiedBefore=None, modifiedAfter=None)
Loads JSON files and returns the results as a DataFrame.
JSON Lines (newline-delimited JSON) is supported by default. For JSON (one record per file), set the multiLine parameter to true.
If the schema parameter is not specified, this function goes through the input once to determine the input schema.
New in version 1.4.0.
Parameters
path [str, list or RDD] string represents path to the JSON dataset, or a list of paths, or RDD ofStrings storing JSON objects.
schema [pyspark.sql.types.StructType or str, optional] an optional pyspark.sql.types.StructType for the input schema or a DDL-formatted string (For example col0 INT, col1 DOUBLE).
primitivesAsString [str or bool, optional] infers all primitive values as a string type. If None is set, it uses the default value, false.
prefersDecimal [str or bool, optional] infers all floating-point values as a decimal type. If the values do not fit in decimal, then it infers them as doubles. If None is set, it uses the default value, false.
allowComments [str or bool, optional] ignores Java/C++ style comments in JSON records. If None is set, it uses the default value, false.
allowUnquotedFieldNames [str or bool, optional] allows unquoted JSON field names. If None is set, it uses the default value, false.
allowSingleQuotes [str or bool, optional] allows single quotes in addition to double quotes. If None is set, it uses the default value, true.
allowNumericLeadingZero [str or bool, optional] allows leading zeros in numbers (e.g. 00012). If None is set, it uses the default value, false.
allowBackslashEscapingAnyCharacter [str or bool, optional] allows accepting quoting of all characters using the backslash quoting mechanism. If None is set, it uses the default value, false.
mode [str, optional] allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, PERMISSIVE.
• PERMISSIVE: when it meets a corrupted record, puts the malformed string into a field configured by columnNameOfCorruptRecord, and sets malformed fields to null. To keep corrupt records, a user can set a string type field named columnNameOfCorruptRecord in a user-defined schema. If a schema does not have the field, it drops corrupt records during parsing. When inferring a schema, it implicitly adds a columnNameOfCorruptRecord field in an output schema.
• DROPMALFORMED: ignores the whole corrupted records.
• FAILFAST: throws an exception when it meets corrupted records.
columnNameOfCorruptRecord [str, optional] allows renaming the new field having the malformed string created by PERMISSIVE mode. This overrides spark.sql.columnNameOfCorruptRecord. If None is set, it uses the value specified in spark.sql.columnNameOfCorruptRecord.
dateFormat [str, optional] sets the string that indicates a date format. Custom date formats follow the formats at datetime pattern. This applies to date type. If None is set, it uses the default value, yyyy-MM-dd.
timestampFormat [str, optional] sets the string that indicates a timestamp format. Custom date formats follow the formats at datetime pattern. This applies to timestamp type. If None is set, it uses the default value, yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX].
multiLine [str or bool, optional] parse one record, which may span multiple lines, per file. If None is set, it uses the default value, false.
allowUnquotedControlChars [str or bool, optional] allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not.
encoding [str or bool, optional] allows forcibly setting one of the standard basic or extended encodings for the JSON files. For example UTF-16BE, UTF-32LE. If None is set, the encoding of input JSON will be detected automatically when the multiLine option is set to true.
lineSep [str, optional] defines the line separator that should be used for parsing. If None is set,it covers all \r, \r\n and \n.
samplingRatio [str or float, optional] defines fraction of input JSON objects used for schema inferring. If None is set, it uses the default value, 1.0.
dropFieldIfAllNull [str or bool, optional] whether to ignore column of all null values or empty array/struct during schema inference. If None is set, it uses the default value, false.
locale [str, optional] sets a locale as language tag in IETF BCP 47 format. If None is set, it uses the default value, en-US. For instance, locale is used while parsing dates and timestamps.
pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.
recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.
allowNonNumericNumbers [str or bool] allows the JSON parser to recognize a set of "Not-a-Number" (NaN) tokens as legal floating number values. If None is set, it uses the default value, true.
• +INF: for positive infinity, as well as alias of +Infinity and Infinity.
• -INF: for negative infinity, alias -Infinity.
• NaN: for other not-a-numbers, like result of division by zero.
modifiedBefore [str, optional] an optional timestamp to only include files with modification times occurring before the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
modifiedAfter [str, optional] an optional timestamp to only include files with modification times occurring after the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
Examples
>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
[('age', 'bigint'), ('name', 'string')]
>>> rdd = sc.textFile('python/test_support/sql/people.json')
>>> df2 = spark.read.json(rdd)
>>> df2.dtypes
[('age', 'bigint'), ('name', 'string')]
pyspark.sql.DataFrameReader.load
DataFrameReader.load(path=None, format=None, schema=None, **options)
Loads data from a data source and returns it as a DataFrame.
New in version 1.4.0.
Parameters
path [str or list, optional] optional string or a list of strings for file-system backed data sources.
format [str, optional] optional string for format of the data source. Defaults to 'parquet'.
schema [pyspark.sql.types.StructType or str, optional] optional pyspark.sql.types.StructType for the input schema or a DDL-formatted string (For example col0 INT, col1 DOUBLE).
**options [dict] all other string options
Examples
>>> df = spark.read.format("parquet").load('python/test_support/sql/parquet_→˓partitioned',... opt1=True, opt2=1, opt3='str')>>> df.dtypes[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
>>> df = spark.read.format('json').load(['python/test_support/sql/people.json',
...     'python/test_support/sql/people1.json'])
>>> df.dtypes
[('age', 'bigint'), ('aka', 'string'), ('name', 'string')]
pyspark.sql.DataFrameReader.option
DataFrameReader.option(key, value)
Adds an input option for the underlying data source.
You can set the following option(s) for reading files:
• timeZone: sets the string that indicates a time zone ID to be used to parse timestamps in the JSON/CSV datasources or partition values. The following formats of timeZone are supported:
– Region-based zone ID: It should have the form ‘area/city’, such as ‘America/Los_Angeles’.
– Zone offset: It should be in the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'.
Other short names like 'CST' are not recommended because they can be ambiguous. If it isn't set, the current value of the SQL config spark.sql.session.timeZone is used by default.
• pathGlobFilter: an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.
• modifiedBefore: an optional timestamp to only include files with modification times occurring before the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
• modifiedAfter: an optional timestamp to only include files with modification times occurring after the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
New in version 1.5.
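A hedged sketch chaining option() calls before a load (the option values are illustrative):

>>> df = (spark.read
...     .option("timeZone", "America/Los_Angeles")
...     .option("pathGlobFilter", "*.json")
...     .json("python/test_support/sql/people.json"))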
pyspark.sql.DataFrameReader.options
DataFrameReader.options(**options)
Adds input options for the underlying data source.
You can set the following option(s) for reading files:
• timeZone: sets the string that indicates a time zone ID to be used to parse timestamps in the JSON/CSV datasources or partition values. The following formats of timeZone are supported:
– Region-based zone ID: It should have the form ‘area/city’, such as ‘America/Los_Angeles’.
– Zone offset: It should be in the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'.
Other short names like 'CST' are not recommended because they can be ambiguous. If it isn't set, the current value of the SQL config spark.sql.session.timeZone is used by default.
• pathGlobFilter: an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.
• modifiedBefore: an optional timestamp to only include files with modification times occurring before the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
• modifiedAfter: an optional timestamp to only include files with modification times occurring after the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
New in version 1.4.
pyspark.sql.DataFrameReader.orc
DataFrameReader.orc(path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None, modifiedBefore=None, modifiedAfter=None)
Loads ORC files, returning the result as a DataFrame.
New in version 1.5.0.
Parameters
path [str or list]
mergeSchema [str or bool, optional] sets whether we should merge schemas collected from all ORC part-files. This will override spark.sql.orc.mergeSchema. The default value is specified in spark.sql.orc.mergeSchema.
pathGlobFilter [str or bool] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.
recursiveFileLookup [str or bool] recursively scan a directory for files. Using this option disables partition discovery.
modifiedBefore [str, optional] an optional timestamp to only include files with modification times occurring before the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
modifiedAfter [str, optional] an optional timestamp to only include files with modification times occurring after the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
Examples
>>> df = spark.read.orc('python/test_support/sql/orc_partitioned')
>>> df.dtypes
[('a', 'bigint'), ('b', 'int'), ('c', 'int')]
pyspark.sql.DataFrameReader.parquet
DataFrameReader.parquet(*paths, **options)
Loads Parquet files, returning the result as a DataFrame.
New in version 1.4.0.
Parameters
paths [str]
Other Parameters
mergeSchema [str or bool, optional] sets whether we should merge schemas collected from all Parquet part-files. This will override spark.sql.parquet.mergeSchema. The default value is specified in spark.sql.parquet.mergeSchema.
pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.
recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.
modifiedBefore (batch only) [str, optional] an optional timestamp to only include files with modification times occurring before the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
modifiedAfter (batch only) [str, optional] an optional timestamp to only include files with modification times occurring after the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)
Examples
>>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned')
>>> df.dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
pyspark.sql.DataFrameReader.schema
DataFrameReader.schema(schema)
Specifies the input schema.

Some data sources (e.g. JSON) can infer the input schema automatically from data. By specifying the schema here, the underlying data source can skip the schema inference step, and thus speed up data loading.
New in version 1.4.0.
Parameters
schema [pyspark.sql.types.StructType or str] a pyspark.sql.types.StructType object or a DDL-formatted string (for example col0 INT, col1 DOUBLE).

>>> s = spark.read.schema("col0 INT, col1 DOUBLE")
pyspark.sql.DataFrameReader.table
DataFrameReader.table(tableName)
Returns the specified table as a DataFrame.
New in version 1.4.0.
Parameters
tableName [str] string, name of the table.
Examples
>>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned')
>>> df.createOrReplaceTempView('tmpTable')
>>> spark.read.table('tmpTable').dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
pyspark.sql.DataFrameWriter.bucketBy
DataFrameWriter.bucketBy(numBuckets, col, *cols)
Buckets the output by the given columns. If specified, the output is laid out on the file system similar to Hive's bucketing scheme.
New in version 2.3.0.
Parameters
numBuckets [int] the number of buckets to save
col [str, list or tuple] a name of a column, or a list of names.
cols [str] additional names (optional). If col is a list it should be empty.
Notes
Applicable for file-based data sources in combination with DataFrameWriter.saveAsTable().
Examples
>>> (df.write.format('parquet')
...     .bucketBy(100, 'year', 'month')
...     .mode("overwrite")
...     .saveAsTable('bucketed_table'))
pyspark.sql.DataFrameWriter.csv
DataFrameWriter.csv(path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None, lineSep=None)
Saves the content of the DataFrame in CSV format at the specified path.
New in version 2.0.0.
Parameters
path [str] the path in any Hadoop supported file system
mode [str, optional] specifies the behavior of the save operation when data already exists.
• append: Append contents of this DataFrame to existing data.
• overwrite: Overwrite existing data.
• ignore: Silently ignore this operation if data already exists.
3.1. Spark SQL 61
pyspark Documentation, Release master
• error or errorifexists (default case): Throw an exception if data already exists.
compression [str, optional] compression codec to use when saving to file. This can be one of the known case-insensitive shortened names (none, bzip2, gzip, lz4, snappy and deflate).
sep [str, optional] sets a separator (one or more characters) for each field and value. If None is set, it uses the default value, ,.
quote [str, optional] sets a single character used for escaping quoted values where the separator can be part of the value. If None is set, it uses the default value, ". If an empty string is set, it uses u0000 (null character).
escape [str, optional] sets a single character used for escaping quotes inside an already quoted value. If None is set, it uses the default value, \
escapeQuotes [str or bool, optional] a flag indicating whether values containing quotes should always be enclosed in quotes. If None is set, it uses the default value true, escaping all values containing a quote character.
quoteAll [str or bool, optional] a flag indicating whether all values should always be enclosed in quotes. If None is set, it uses the default value false, only escaping values containing a quote character.
header [str or bool, optional] writes the names of columns as the first line. If None is set, it uses the default value, false.
nullValue [str, optional] sets the string representation of a null value. If None is set, it uses the default value, empty string.
dateFormat [str, optional] sets the string that indicates a date format. Custom date formats follow the formats at datetime pattern. This applies to date type. If None is set, it uses the default value, yyyy-MM-dd.
timestampFormat [str, optional] sets the string that indicates a timestamp format. Custom date formats follow the formats at datetime pattern. This applies to timestamp type. If None is set, it uses the default value, yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX].
ignoreLeadingWhiteSpace [str or bool, optional] a flag indicating whether or not leading whitespaces from values being written should be skipped. If None is set, it uses the default value, true.
ignoreTrailingWhiteSpace [str or bool, optional] a flag indicating whether or not trailing whitespaces from values being written should be skipped. If None is set, it uses the default value, true.
charToEscapeQuoteEscaping [str, optional] sets a single character used for escaping the escape for the quote character. If None is set, the default value is the escape character when escape and quote characters are different, \0 otherwise.
encoding [str, optional] sets the encoding (charset) of saved csv files. If None is set, the default UTF-8 charset will be used.
emptyValue [str, optional] sets the string representation of an empty value. If None is set, it uses the default value, "".
lineSep [str, optional] defines the line separator that should be used for writing. If None is set, it uses the default value, \n. Maximum length is 1 character.
Examples
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
pyspark.sql.DataFrameWriter.format
DataFrameWriter.format(source)
Specifies the underlying output data source.
New in version 1.4.0.
Parameters
source [str] string, name of the data source, e.g. ‘json’, ‘parquet’.
Examples
>>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data'))
pyspark.sql.DataFrameWriter.insertInto
DataFrameWriter.insertInto(tableName, overwrite=None)
Inserts the content of the DataFrame to the specified table.
It requires that the schema of the DataFrame is the same as the schema of the table.

If overwrite is set to True, any existing data in the table is overwritten.
New in version 1.4.
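A minimal usage sketch; the table name below is hypothetical and must already exist with a matching schema:

>>> df.write.insertInto('people_table')                   # append (hypothetical table)
>>> df.write.insertInto('people_table', overwrite=True)   # replace existing data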
pyspark.sql.DataFrameWriter.jdbc
DataFrameWriter.jdbc(url, table, mode=None, properties=None)
Saves the content of the DataFrame to an external database table via JDBC.
New in version 1.4.0.
Parameters
url [str] a JDBC URL of the form jdbc:subprotocol:subname
table [str] Name of the table in the external database.
mode [str, optional] specifies the behavior of the save operation when data already exists.
• append: Append contents of this DataFrame to existing data.
• overwrite: Overwrite existing data.
• ignore: Silently ignore this operation if data already exists.
• error or errorifexists (default case): Throw an exception if data already exists.
properties [dict] a dictionary of JDBC database connection arguments. Normally at least properties "user" and "password" with their corresponding values. For example { 'user' : 'SYSTEM', 'password' : 'mypassword' }
Notes
Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash your external database systems.
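A hedged usage sketch; the JDBC URL, table name, and credentials below are placeholders, and the matching JDBC driver must be available to Spark:

>>> props = {'user': 'SYSTEM', 'password': 'mypassword'}    # hypothetical credentials
>>> df.write.jdbc('jdbc:postgresql://localhost:5432/mydb',  # hypothetical URL
...               'people', mode='append', properties=props)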
pyspark.sql.DataFrameWriter.json
DataFrameWriter.json(path, mode=None, compression=None, dateFormat=None, timestampFormat=None, lineSep=None, encoding=None, ignoreNullFields=None)
Saves the content of the DataFrame in JSON format (JSON Lines text format or newline-delimited JSON) at the specified path.
New in version 1.4.0.
Parameters
path [str] the path in any Hadoop supported file system
mode [str, optional] specifies the behavior of the save operation when data already exists.
• append: Append contents of this DataFrame to existing data.
• overwrite: Overwrite existing data.
• ignore: Silently ignore this operation if data already exists.
• error or errorifexists (default case): Throw an exception if data already exists.
compression [str, optional] compression codec to use when saving to file. This can be one of the known case-insensitive shortened names (none, bzip2, gzip, lz4, snappy and deflate).
dateFormat [str, optional] sets the string that indicates a date format. Custom date formats follow the formats at datetime pattern. This applies to date type. If None is set, it uses the default value, yyyy-MM-dd.
timestampFormat [str, optional] sets the string that indicates a timestamp format. Custom date formats follow the formats at datetime pattern. This applies to timestamp type. If None is set, it uses the default value, yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX].
encoding [str, optional] specifies encoding (charset) of saved json files. If None is set, the default UTF-8 charset will be used.
lineSep [str, optional] defines the line separator that should be used for writing. If None is set, it uses the default value, \n.
ignoreNullFields [str or bool, optional] Whether to ignore null fields when generating JSON objects. If None is set, it uses the default value, true.
Examples
>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
pyspark.sql.DataFrameWriter.mode
DataFrameWriter.mode(saveMode)
Specifies the behavior when data or table already exists.
Options include:
• append: Append contents of this DataFrame to existing data.
• overwrite: Overwrite existing data.
• error or errorifexists: Throw an exception if data already exists.
• ignore: Silently ignore this operation if data already exists.
New in version 1.4.0.
Examples
>>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
pyspark.sql.DataFrameWriter.option
DataFrameWriter.option(key, value)
Adds an output option for the underlying data source.
You can set the following option(s) for writing files:
• timeZone: sets the string that indicates a time zone ID to be used to format timestamps in the JSON/CSV datasources or partition values. The following formats of timeZone are supported:
– Region-based zone ID: It should have the form ‘area/city’, such as ‘America/Los_Angeles’.
– Zone offset: It should be in the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'.

Using other short names like 'CST' is not recommended because they can be ambiguous. If it isn't set, the current value of the SQL config spark.sql.session.timeZone is used by default.
New in version 1.5.
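For illustration, a minimal sketch that sets the timeZone option before writing; the output path below is hypothetical:

>>> df.write.option('timeZone', 'America/Los_Angeles').csv('/tmp/out')  # hypothetical path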
pyspark.sql.DataFrameWriter.options
DataFrameWriter.options(**options)
Adds output options for the underlying data source.
You can set the following option(s) for writing files:
• timeZone: sets the string that indicates a time zone ID to be used to format timestamps in the JSON/CSV datasources or partition values. The following formats of timeZone are supported:
– Region-based zone ID: It should have the form ‘area/city’, such as ‘America/Los_Angeles’.
– Zone offset: It should be in the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'.

Using other short names like 'CST' is not recommended because they can be ambiguous. If it isn't set, the current value of the SQL config spark.sql.session.timeZone is used by default.
New in version 1.4.
pyspark.sql.DataFrameWriter.orc
DataFrameWriter.orc(path, mode=None, partitionBy=None, compression=None)
Saves the content of the DataFrame in ORC format at the specified path.
New in version 1.5.0.
Parameters
path [str] the path in any Hadoop supported file system
mode [str, optional] specifies the behavior of the save operation when data already exists.
• append: Append contents of this DataFrame to existing data.
• overwrite: Overwrite existing data.
• ignore: Silently ignore this operation if data already exists.
• error or errorifexists (default case): Throw an exception if data already exists.
partitionBy [str or list, optional] names of partitioning columns
compression [str, optional] compression codec to use when saving to file. This can be one of the known case-insensitive shortened names (none, snappy, zlib, and lzo). This will override orc.compress and spark.sql.orc.compression.codec. If None is set, it uses the value specified in spark.sql.orc.compression.codec.
Examples
>>> orc_df = spark.read.orc('python/test_support/sql/orc_partitioned')
>>> orc_df.write.orc(os.path.join(tempfile.mkdtemp(), 'data'))
pyspark.sql.DataFrameWriter.parquet
DataFrameWriter.parquet(path, mode=None, partitionBy=None, compression=None)
Saves the content of the DataFrame in Parquet format at the specified path.
New in version 1.4.0.
Parameters
path [str] the path in any Hadoop supported file system
mode [str, optional] specifies the behavior of the save operation when data already exists.
• append: Append contents of this DataFrame to existing data.
• overwrite: Overwrite existing data.
• ignore: Silently ignore this operation if data already exists.
• error or errorifexists (default case): Throw an exception if data already exists.
partitionBy [str or list, optional] names of partitioning columns
compression [str, optional] compression codec to use when saving to file. This can be one of the known case-insensitive shortened names (none, uncompressed, snappy, gzip, lzo, brotli, lz4, and zstd). This will override spark.sql.parquet.compression.codec. If None is set, it uses the value specified in spark.sql.parquet.compression.codec.
Examples
>>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
pyspark.sql.DataFrameWriter.partitionBy
DataFrameWriter.partitionBy(*cols)
Partitions the output by the given columns on the file system.
If specified, the output is laid out on the file system similar to Hive’s partitioning scheme.
New in version 1.4.0.
Parameters
cols [str or list] name of columns
Examples
>>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
pyspark.sql.DataFrameWriter.save
DataFrameWriter.save(path=None, format=None, mode=None, partitionBy=None, **options)
Saves the contents of the DataFrame to a data source.

The data source is specified by the format and a set of options. If format is not specified, the default data source configured by spark.sql.sources.default will be used.
New in version 1.4.0.
Parameters
path [str, optional] the path in a Hadoop supported file system
format [str, optional] the format used to save
mode [str, optional] specifies the behavior of the save operation when data already exists.
• append: Append contents of this DataFrame to existing data.
• overwrite: Overwrite existing data.
• ignore: Silently ignore this operation if data already exists.
• error or errorifexists (default case): Throw an exception if data already exists.
partitionBy [list, optional] names of partitioning columns
**options [dict] all other string options
Examples
>>> df.write.mode("append").save(os.path.join(tempfile.mkdtemp(), 'data'))
pyspark.sql.DataFrameWriter.saveAsTable
DataFrameWriter.saveAsTable(name, format=None, mode=None, partitionBy=None, **options)
Saves the content of the DataFrame as the specified table.

If the table already exists, the behavior of this function depends on the save mode, specified by the mode function (defaults to throwing an exception). When mode is overwrite, the schema of the DataFrame does not need to be the same as that of the existing table.
• append: Append contents of this DataFrame to existing data.
• overwrite: Overwrite existing data.
• error or errorifexists: Throw an exception if data already exists.
• ignore: Silently ignore this operation if data already exists.
New in version 1.4.0.
Parameters
name [str] the table name
format [str, optional] the format used to save
mode [str, optional] one of append, overwrite, error, errorifexists, ignore (default: error)
partitionBy [str or list] names of partitioning columns
**options [dict] all other string options
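A minimal usage sketch; the table name below is hypothetical:

>>> df.write.saveAsTable('people_copy', format='parquet', mode='overwrite')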
pyspark.sql.DataFrameWriter.sortBy
DataFrameWriter.sortBy(col, *cols)
Sorts the output in each bucket by the given columns on the file system.
New in version 2.3.0.
Parameters
col [str, tuple or list] a name of a column, or a list of names.
cols [str] additional names (optional). If col is a list it should be empty.
Examples
>>> (df.write.format('parquet')
...     .bucketBy(100, 'year', 'month')
...     .sortBy('day')
...     .mode("overwrite")
...     .saveAsTable('sorted_bucketed_table'))
pyspark.sql.DataFrameWriter.text
DataFrameWriter.text(path, compression=None, lineSep=None)
Saves the content of the DataFrame in a text file at the specified path. The text files will be encoded as UTF-8.
New in version 1.6.0.
Parameters
path [str] the path in any Hadoop supported file system
compression [str, optional] compression codec to use when saving to file. This can be one of the known case-insensitive shortened names (none, bzip2, gzip, lz4, snappy and deflate).
lineSep [str, optional] defines the line separator that should be used for writing. If None is set, it uses the default value, \n.
The DataFrame must have only one column that is of string type.
Each row becomes a new line in the output file.
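For illustration, a minimal sketch that selects a single string column before writing; the column name 'name' is taken from the surrounding doctest examples:

>>> df.select('name').write.text(os.path.join(tempfile.mkdtemp(), 'data'))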
3.1.4 DataFrame APIs
DataFrame.agg(*exprs)  Aggregate on the entire DataFrame without groups (shorthand for df.groupBy().agg()).
DataFrame.alias(alias)  Returns a new DataFrame with an alias set.
DataFrame.approxQuantile(col, probabilities, ...)  Calculates the approximate quantiles of numerical columns of a DataFrame.
DataFrame.cache()  Persists the DataFrame with the default storage level (MEMORY_AND_DISK).
DataFrame.checkpoint([eager])  Returns a checkpointed version of this Dataset.
DataFrame.coalesce(numPartitions)  Returns a new DataFrame that has exactly numPartitions partitions.
DataFrame.colRegex(colName)  Selects column based on the column name specified as a regex and returns it as Column.
DataFrame.collect()  Returns all the records as a list of Row.
DataFrame.columns  Returns all column names as a list.
DataFrame.corr(col1, col2[, method])  Calculates the correlation of two columns of a DataFrame as a double value.
DataFrame.count()  Returns the number of rows in this DataFrame.
DataFrame.cov(col1, col2)  Calculate the sample covariance for the given columns, specified by their names, as a double value.
DataFrame.createGlobalTempView(name)  Creates a global temporary view with this DataFrame.
DataFrame.createOrReplaceGlobalTempView(name)  Creates or replaces a global temporary view using the given name.
DataFrame.createOrReplaceTempView(name)  Creates or replaces a local temporary view with this DataFrame.
DataFrame.createTempView(name)  Creates a local temporary view with this DataFrame.
DataFrame.crossJoin(other)  Returns the cartesian product with another DataFrame.
DataFrame.crosstab(col1, col2)  Computes a pair-wise frequency table of the given columns.
DataFrame.cube(*cols)  Create a multi-dimensional cube for the current DataFrame using the specified columns, so we can run aggregations on them.
DataFrame.describe(*cols)  Computes basic statistics for numeric and string columns.
DataFrame.distinct()  Returns a new DataFrame containing the distinct rows in this DataFrame.
DataFrame.drop(*cols)  Returns a new DataFrame that drops the specified column.
DataFrame.dropDuplicates([subset])  Return a new DataFrame with duplicate rows removed, optionally only considering certain columns.
DataFrame.drop_duplicates([subset])  drop_duplicates() is an alias for dropDuplicates().
DataFrame.dropna([how, thresh, subset])  Returns a new DataFrame omitting rows with null values.
DataFrame.dtypes  Returns all column names and their data types as a list.
DataFrame.exceptAll(other)  Return a new DataFrame containing rows in this DataFrame but not in another DataFrame while preserving duplicates.
DataFrame.explain([extended, mode])  Prints the (logical and physical) plans to the console for debugging purpose.
DataFrame.fillna(value[, subset])  Replace null values, alias for na.fill().
DataFrame.filter(condition)  Filters rows using the given condition.
DataFrame.first()  Returns the first row as a Row.
DataFrame.foreach(f)  Applies the f function to all Row of this DataFrame.
DataFrame.foreachPartition(f)  Applies the f function to each partition of this DataFrame.
DataFrame.freqItems(cols[, support])  Finding frequent items for columns, possibly with false positives.
DataFrame.groupBy(*cols)  Groups the DataFrame using the specified columns, so we can run aggregation on them.
DataFrame.head([n])  Returns the first n rows.
DataFrame.hint(name, *parameters)  Specifies some hint on the current DataFrame.
DataFrame.inputFiles()  Returns a best-effort snapshot of the files that compose this DataFrame.
DataFrame.intersect(other)  Return a new DataFrame containing rows only in both this DataFrame and another DataFrame.
DataFrame.intersectAll(other)  Return a new DataFrame containing rows in both this DataFrame and another DataFrame while preserving duplicates.
DataFrame.isLocal()  Returns True if the collect() and take() methods can be run locally (without any Spark executors).
DataFrame.isStreaming  Returns True if this Dataset contains one or more sources that continuously return data as it arrives.
DataFrame.join(other[, on, how])  Joins with another DataFrame, using the given join expression.
DataFrame.limit(num)  Limits the result count to the number specified.
DataFrame.localCheckpoint([eager])  Returns a locally checkpointed version of this Dataset.
DataFrame.mapInPandas(func, schema)  Maps an iterator of batches in the current DataFrame using a Python native function that takes and outputs a pandas DataFrame, and returns the result as a DataFrame.
DataFrame.na  Returns a DataFrameNaFunctions for handling missing values.
DataFrame.orderBy(*cols, **kwargs)  Returns a new DataFrame sorted by the specified column(s).
DataFrame.persist([storageLevel])  Sets the storage level to persist the contents of the DataFrame across operations after the first time it is computed.
DataFrame.printSchema()  Prints out the schema in the tree format.
DataFrame.randomSplit(weights[, seed])  Randomly splits this DataFrame with the provided weights.
DataFrame.rdd  Returns the content as a pyspark.RDD of Row.
DataFrame.registerTempTable(name)  Registers this DataFrame as a temporary table using the given name.
DataFrame.repartition(numPartitions, *cols)  Returns a new DataFrame partitioned by the given partitioning expressions.
DataFrame.repartitionByRange(numPartitions, ...)  Returns a new DataFrame partitioned by the given partitioning expressions.
DataFrame.replace(to_replace[, value, subset])  Returns a new DataFrame replacing a value with another value.
DataFrame.rollup(*cols)  Create a multi-dimensional rollup for the current DataFrame using the specified columns, so we can run aggregation on them.
DataFrame.sameSemantics(other)  Returns True when the logical query plans inside both DataFrames are equal and therefore return same results.
DataFrame.sample([withReplacement, ...])  Returns a sampled subset of this DataFrame.
DataFrame.sampleBy(col, fractions[, seed])  Returns a stratified sample without replacement based on the fraction given on each stratum.
DataFrame.schema  Returns the schema of this DataFrame as a pyspark.sql.types.StructType.
DataFrame.select(*cols)  Projects a set of expressions and returns a new DataFrame.
DataFrame.selectExpr(*expr)  Projects a set of SQL expressions and returns a new DataFrame.
DataFrame.semanticHash()  Returns a hash code of the logical query plan against this DataFrame.
DataFrame.show([n, truncate, vertical])  Prints the first n rows to the console.
DataFrame.sort(*cols, **kwargs)  Returns a new DataFrame sorted by the specified column(s).
DataFrame.sortWithinPartitions(*cols, **kwargs)  Returns a new DataFrame with each partition sorted by the specified column(s).
DataFrame.stat  Returns a DataFrameStatFunctions for statistic functions.
DataFrame.storageLevel  Get the DataFrame's current storage level.
DataFrame.subtract(other)  Return a new DataFrame containing rows in this DataFrame but not in another DataFrame.
DataFrame.summary(*statistics)  Computes specified statistics for numeric and string columns.
DataFrame.tail(num)  Returns the last num rows as a list of Row.
DataFrame.take(num)  Returns the first num rows as a list of Row.
DataFrame.toDF(*cols)  Returns a new DataFrame with new specified column names.
DataFrame.toJSON([use_unicode])  Converts a DataFrame into a RDD of string.
DataFrame.toLocalIterator([prefetchPartitions])  Returns an iterator that contains all of the rows in this DataFrame.
DataFrame.toPandas()  Returns the contents of this DataFrame as Pandas pandas.DataFrame.
DataFrame.transform(func)  Returns a new DataFrame.
DataFrame.union(other)  Return a new DataFrame containing union of rows in this and another DataFrame.
DataFrame.unionAll(other)  Return a new DataFrame containing union of rows in this and another DataFrame.
DataFrame.unionByName(other[, ...])  Returns a new DataFrame containing union of rows in this and another DataFrame.
DataFrame.unpersist([blocking])  Marks the DataFrame as non-persistent, and remove all blocks for it from memory and disk.
DataFrame.where(condition)  where() is an alias for filter().
DataFrame.withColumn(colName, col)  Returns a new DataFrame by adding a column or replacing the existing column that has the same name.
DataFrame.withColumnRenamed(existing, new)  Returns a new DataFrame by renaming an existing column.
DataFrame.withWatermark(eventTime, ...)  Defines an event time watermark for this DataFrame.
DataFrame.write  Interface for saving the content of the non-streaming DataFrame out into external storage.
DataFrame.writeStream  Interface for saving the content of the streaming DataFrame out into external storage.
DataFrame.writeTo(table)  Create a write configuration builder for v2 sources.
DataFrameNaFunctions.drop([how, thresh, subset])  Returns a new DataFrame omitting rows with null values.
DataFrameNaFunctions.fill(value[, subset])  Replace null values, alias for na.fill().
DataFrameNaFunctions.replace(to_replace[, ...])  Returns a new DataFrame replacing a value with another value.
DataFrameStatFunctions.approxQuantile(col, ...)  Calculates the approximate quantiles of numerical columns of a DataFrame.
DataFrameStatFunctions.corr(col1, col2[, method])  Calculates the correlation of two columns of a DataFrame as a double value.
DataFrameStatFunctions.cov(col1, col2)  Calculate the sample covariance for the given columns, specified by their names, as a double value.
DataFrameStatFunctions.crosstab(col1, col2)  Computes a pair-wise frequency table of the given columns.
DataFrameStatFunctions.freqItems(cols[, support])  Finding frequent items for columns, possibly with false positives.
DataFrameStatFunctions.sampleBy(col, fractions)  Returns a stratified sample without replacement based on the fraction given on each stratum.
pyspark.sql.DataFrame.agg
DataFrame.agg(*exprs)
Aggregate on the entire DataFrame without groups (shorthand for df.groupBy().agg()).
New in version 1.3.0.
Examples
>>> df.agg({"age": "max"}).collect()
[Row(max(age)=5)]
>>> from pyspark.sql import functions as F
>>> df.agg(F.min(df.age)).collect()
[Row(min(age)=2)]
pyspark.sql.DataFrame.alias
DataFrame.alias(alias)
Returns a new DataFrame with an alias set.
New in version 1.3.0.
Parameters
alias [str] an alias name to be set for the DataFrame.
Examples
>>> from pyspark.sql.functions import *
>>> df_as1 = df.alias("df_as1")
>>> df_as2 = df.alias("df_as2")
>>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner')
>>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age") \
...     .sort(desc("df_as1.name")).collect()
[Row(name='Bob', name='Bob', age=5), Row(name='Alice', name='Alice', age=2)]
pyspark.sql.DataFrame.approxQuantile
DataFrame.approxQuantile(col, probabilities, relativeError)
Calculates the approximate quantiles of numerical columns of a DataFrame.

The result of this algorithm has the following deterministic bound: If the DataFrame has N elements and if we request the quantile at probability p up to error err, then the algorithm will return a sample x from the DataFrame so that the exact rank of x is close to (p * N). More precisely,
floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).
This method implements a variation of the Greenwald-Khanna algorithm (with some speed optimizations). The algorithm was first presented in "Space-efficient Online Computation of Quantile Summaries" by Greenwald and Khanna (https://doi.org/10.1145/375663.375670).

Note that null values will be ignored in numerical columns before calculation. For columns only containing null values, an empty list is returned.
New in version 2.0.0.
Parameters
col [str, tuple or list] Can be a single column name, or a list of names for multiple columns.
Changed in version 2.2: Added support for multiple columns.
probabilities [list or tuple] a list of quantile probabilities. Each number must belong to [0, 1]. For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
relativeError [float] The relative target precision to achieve (>= 0). If set to zero, the exact quantiles are computed, which could be very expensive. Note that values greater than 1 are accepted but give the same result as 1.
Returns
list the approximate quantiles at the given probabilities. If the input col is a string, the output is a list of floats. If the input col is a list or tuple of strings, the output is also a list, but each element in it is a list of floats, i.e., the output is a list of lists of floats.
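For illustration, a minimal sketch using the two-row doctest df from the surrounding examples; with relativeError=0.0 the exact quantiles are computed:

>>> quantiles = df.approxQuantile("age", [0.25, 0.5, 0.75], 0.0)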
pyspark.sql.DataFrame.cache
DataFrame.cache()
Persists the DataFrame with the default storage level (MEMORY_AND_DISK).
New in version 1.3.0.
Notes
The default storage level has changed to MEMORY_AND_DISK to match Scala in 2.0.
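A minimal usage sketch; cache() returns the DataFrame itself, so it can be chained (the count below uses the two-row doctest df from the surrounding examples):

>>> df.cache().count()
2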
pyspark.sql.DataFrame.checkpoint
DataFrame.checkpoint(eager=True)
Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the logical plan of this DataFrame, which is especially useful in iterative algorithms where the plan may grow exponentially. It will be saved to files inside the checkpoint directory set with SparkContext.setCheckpointDir().
New in version 2.1.0.
Parameters
eager [bool, optional] Whether to checkpoint this DataFrame immediately
Notes
This API is experimental.
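A hedged usage sketch; the checkpoint directory below is hypothetical:

>>> spark.sparkContext.setCheckpointDir('/tmp/checkpoints')  # hypothetical directory
>>> checkpointed = df.checkpoint()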
pyspark.sql.DataFrame.coalesce
DataFrame.coalesce(numPartitions)
Returns a new DataFrame that has exactly numPartitions partitions.

Similar to coalesce defined on an RDD, this operation results in a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not be a shuffle; instead each of the 100 new partitions will claim 10 of the current partitions. If a larger number of partitions is requested, it will stay at the current number of partitions.

However, if you're doing a drastic coalesce, e.g. to numPartitions = 1, this may result in your computation taking place on fewer nodes than you like (e.g. one node in the case of numPartitions = 1). To avoid this, you can call repartition(). This will add a shuffle step, but means the current upstream partitions will be executed in parallel (per whatever the current partitioning is).
New in version 1.4.0.
Parameters
numPartitions [int] specify the target number of partitions
Examples
>>> df.coalesce(1).rdd.getNumPartitions()
1
pyspark.sql.DataFrame.colRegex
DataFrame.colRegex(colName)
Selects column based on the column name specified as a regex and returns it as Column.
New in version 2.3.0.
Parameters
colName [str] string, column name specified as a regex.
Examples
>>> df = spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["Col1", "Col2"])
>>> df.select(df.colRegex("`(Col1)?+.+`")).show()
+----+
|Col2|
+----+
|   1|
|   2|
|   3|
+----+
pyspark.sql.DataFrame.collect
DataFrame.collect()
Returns all the records as a list of Row.
New in version 1.3.0.
Examples
>>> df.collect()
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]
pyspark.sql.DataFrame.columns
property DataFrame.columns
Returns all column names as a list.
New in version 1.3.0.
Examples
>>> df.columns
['age', 'name']
pyspark.sql.DataFrame.corr
DataFrame.corr(col1, col2, method=None)
Calculates the correlation of two columns of a DataFrame as a double value. Currently only supports the Pearson Correlation Coefficient. DataFrame.corr() and DataFrameStatFunctions.corr() are aliases of each other.
New in version 1.4.0.
Parameters
col1 [str] The name of the first column
col2 [str] The name of the second column
method [str, optional] The correlation method. Currently only supports “pearson”
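For illustration, a minimal sketch on hypothetical data (the column values below are made up):

>>> corr_df = spark.createDataFrame([(1, 2.0), (2, 4.0), (3, 6.0)], ["a", "b"])
>>> r = corr_df.corr("a", "b", method="pearson")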
pyspark.sql.DataFrame.count
DataFrame.count()
Returns the number of rows in this DataFrame.
New in version 1.3.0.
Examples
>>> df.count()
2
pyspark.sql.DataFrame.cov
DataFrame.cov(col1, col2)
Calculate the sample covariance for the given columns, specified by their names, as a double value. DataFrame.cov() and DataFrameStatFunctions.cov() are aliases.
New in version 1.4.0.
Parameters
col1 [str] The name of the first column
col2 [str] The name of the second column
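For illustration, a minimal sketch on hypothetical data:

>>> cov_df = spark.createDataFrame([(1, 10.0), (2, 20.0), (3, 30.0)], ["a", "b"])
>>> c = cov_df.cov("a", "b")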
pyspark.sql.DataFrame.createGlobalTempView
DataFrame.createGlobalTempView(name)
Creates a global temporary view with this DataFrame.

The lifetime of this temporary view is tied to this Spark application. Throws TempTableAlreadyExistsException if the view name already exists in the catalog.
New in version 2.1.0.
Examples
>>> df.createGlobalTempView("people")
>>> df2 = spark.sql("select * from global_temp.people")
>>> sorted(df.collect()) == sorted(df2.collect())
True
>>> df.createGlobalTempView("people")
Traceback (most recent call last):
...
AnalysisException: u"Temporary table 'people' already exists;"
>>> spark.catalog.dropGlobalTempView("people")
pyspark.sql.DataFrame.createOrReplaceGlobalTempView
DataFrame.createOrReplaceGlobalTempView(name)
Creates or replaces a global temporary view using the given name.
The lifetime of this temporary view is tied to this Spark application.
New in version 2.2.0.
Examples
>>> df.createOrReplaceGlobalTempView("people")
>>> df2 = df.filter(df.age > 3)
>>> df2.createOrReplaceGlobalTempView("people")
>>> df3 = spark.sql("select * from global_temp.people")
>>> sorted(df3.collect()) == sorted(df2.collect())
True
>>> spark.catalog.dropGlobalTempView("people")
pyspark.sql.DataFrame.createOrReplaceTempView
DataFrame.createOrReplaceTempView(name)
Creates or replaces a local temporary view with this DataFrame.
The lifetime of this temporary table is tied to the SparkSession that was used to create this DataFrame.
New in version 2.0.0.
Examples
>>> df.createOrReplaceTempView("people")
>>> df2 = df.filter(df.age > 3)
>>> df2.createOrReplaceTempView("people")
>>> df3 = spark.sql("select * from people")
>>> sorted(df3.collect()) == sorted(df2.collect())
True
>>> spark.catalog.dropTempView("people")
pyspark.sql.DataFrame.createTempView
DataFrame.createTempView(name)
Creates a local temporary view with this DataFrame.

The lifetime of this temporary table is tied to the SparkSession that was used to create this DataFrame. Throws TempTableAlreadyExistsException if the view name already exists in the catalog.
New in version 2.0.0.
Examples
>>> df.createTempView("people")
>>> df2 = spark.sql("select * from people")
>>> sorted(df.collect()) == sorted(df2.collect())
True
>>> df.createTempView("people")
Traceback (most recent call last):
...
AnalysisException: u"Temporary table 'people' already exists;"
>>> spark.catalog.dropTempView("people")
pyspark.sql.DataFrame.crossJoin
DataFrame.crossJoin(other)
Returns the cartesian product with another DataFrame.
New in version 2.1.0.
Parameters
other [DataFrame] Right side of the cartesian product.
Examples
>>> df.select("age", "name").collect()
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]
>>> df2.select("name", "height").collect()
[Row(name='Tom', height=80), Row(name='Bob', height=85)]
>>> df.crossJoin(df2.select("height")).select("age", "name", "height").collect()
[Row(age=2, name='Alice', height=80), Row(age=2, name='Alice', height=85),
 Row(age=5, name='Bob', height=80), Row(age=5, name='Bob', height=85)]
pyspark.sql.DataFrame.crosstab
DataFrame.crosstab(col1, col2)
Computes a pair-wise frequency table of the given columns. Also known as a contingency table. The number of distinct values for each column should be less than 1e4. At most 1e6 non-zero pair frequencies will be returned. The first column of each row will be the distinct values of col1 and the column names will be the distinct values of col2. The name of the first column will be $col1_$col2. Pairs that have no occurrences will have zero as their counts. DataFrame.crosstab() and DataFrameStatFunctions.crosstab() are aliases.
New in version 1.4.0.
Parameters
col1 [str] The name of the first column. Distinct items will make the first item of each row.
col2 [str] The name of the second column. Distinct items will make the column names of the DataFrame.
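For illustration, a minimal sketch using the doctest df from the surrounding examples; the first column of the result is named 'name_age':

>>> ct = df.crosstab("name", "age")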
pyspark.sql.DataFrame.cube
DataFrame.cube(*cols)
Create a multi-dimensional cube for the current DataFrame using the specified columns, so we can run aggregations on them.
New in version 1.4.0.
Examples
>>> df.cube("name", df.age).count().orderBy("name", "age").show()
+-----+----+-----+
| name| age|count|
+-----+----+-----+
| null|null|    2|
| null|   2|    1|
| null|   5|    1|
|Alice|null|    1|
|Alice|   2|    1|
|  Bob|null|    1|
|  Bob|   5|    1|
+-----+----+-----+
pyspark.sql.DataFrame.describe
DataFrame.describe(*cols)
Computes basic statistics for numeric and string columns.
New in version 1.3.1.
This includes count, mean, stddev, min, and max. If no columns are given, this function computes statistics for all numerical or string columns.
See also:
DataFrame.summary
Notes
This function is meant for exploratory data analysis, as we make no guarantee about the backward compatibility of the schema of the resulting DataFrame.
Use summary for expanded statistics and control over which statistics to compute.
Examples
>>> df.describe(['age']).show()
+-------+------------------+
|summary|               age|
+-------+------------------+
|  count|                 2|
|   mean|               3.5|
| stddev|2.1213203435596424|
|    min|                 2|
|    max|                 5|
+-------+------------------+
>>> df.describe().show()
+-------+------------------+-----+
|summary|               age| name|
+-------+------------------+-----+
|  count|                 2|    2|
|   mean|               3.5| null|
| stddev|2.1213203435596424| null|
|    min|                 2|Alice|
|    max|                 5|  Bob|
+-------+------------------+-----+
pyspark.sql.DataFrame.distinct
DataFrame.distinct()
Returns a new DataFrame containing the distinct rows in this DataFrame.
New in version 1.3.0.
Examples
>>> df.distinct().count()
2
pyspark.sql.DataFrame.drop
DataFrame.drop(*cols)
Returns a new DataFrame that drops the specified column. This is a no-op if schema doesn't contain the given column name(s).
New in version 1.4.0.
Parameters
cols [str or Column] a name of the column, or the Column to drop
Examples
>>> df.drop('age').collect()
[Row(name='Alice'), Row(name='Bob')]

>>> df.drop(df.age).collect()
[Row(name='Alice'), Row(name='Bob')]

>>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect()
[Row(age=5, height=85, name='Bob')]

>>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect()
[Row(age=5, name='Bob', height=85)]

>>> df.join(df2, 'name', 'inner').drop('age', 'height').collect()
[Row(name='Bob')]
pyspark.sql.DataFrame.dropDuplicates
DataFrame.dropDuplicates(subset=None)
Return a new DataFrame with duplicate rows removed, optionally only considering certain columns.

For a static batch DataFrame, it just drops duplicate rows. For a streaming DataFrame, it will keep all data across triggers as intermediate state to drop duplicate rows. You can use withWatermark() to limit how late the duplicate data can be and the system will accordingly limit the state. In addition, data older than the watermark will be dropped to avoid any possibility of duplicates.
drop_duplicates() is an alias for dropDuplicates().
New in version 1.4.0.
Examples
>>> from pyspark.sql import Row
>>> df = sc.parallelize([
...     Row(name='Alice', age=5, height=80),
...     Row(name='Alice', age=5, height=80),
...     Row(name='Alice', age=10, height=80)]).toDF()
>>> df.dropDuplicates().show()
+-----+---+------+
| name|age|height|
+-----+---+------+
|Alice|  5|    80|
|Alice| 10|    80|
+-----+---+------+

>>> df.dropDuplicates(['name', 'height']).show()
+-----+---+------+
| name|age|height|
+-----+---+------+
|Alice|  5|    80|
+-----+---+------+
pyspark.sql.DataFrame.drop_duplicates
DataFrame.drop_duplicates(subset=None)
drop_duplicates() is an alias for dropDuplicates().
New in version 1.4.
pyspark.sql.DataFrame.dropna
DataFrame.dropna(how='any', thresh=None, subset=None)
Returns a new DataFrame omitting rows with null values. DataFrame.dropna() and DataFrameNaFunctions.drop() are aliases of each other.
New in version 1.3.1.
Parameters
how [str, optional] 'any' or 'all'. If 'any', drop a row if it contains any nulls. If 'all', drop a row only if all its values are null.
thresh [int, optional] default None. If specified, drop rows that have fewer than thresh non-null values. This overwrites the how parameter.
subset [str, tuple or list, optional] optional list of column names to consider.
Examples
>>> df4.na.drop().show()
+---+------+-----+
|age|height| name|
+---+------+-----+
| 10|    80|Alice|
+---+------+-----+
pyspark.sql.DataFrame.dtypes
property DataFrame.dtypes
Returns all column names and their data types as a list.
New in version 1.3.0.
Examples
>>> df.dtypes
[('age', 'int'), ('name', 'string')]
pyspark.sql.DataFrame.exceptAll
DataFrame.exceptAll(other)
Return a new DataFrame containing rows in this DataFrame but not in another DataFrame while preserving duplicates.
This is equivalent to EXCEPT ALL in SQL. As standard in SQL, this function resolves columns by position (notby name).
New in version 2.4.0.
Examples
>>> df1 = spark.createDataFrame(
...     [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"])
>>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"])

>>> df1.exceptAll(df2).show()
+---+---+
| C1| C2|
+---+---+
|  a|  1|
|  a|  1|
|  a|  2|
|  c|  4|
+---+---+
pyspark.sql.DataFrame.explain
DataFrame.explain(extended=None, mode=None)
Prints the (logical and physical) plans to the console for debugging purpose.
New in version 1.3.0.
Parameters
extended [bool, optional] default False. If False, prints only the physical plan. When this is a string and mode is not specified, the string is treated as the mode.
mode [str, optional] specifies the expected output format of plans.
• simple: Print only a physical plan.
• extended: Print both logical and physical plans.
• codegen: Print a physical plan and generated codes if they are available.
• cost: Print a logical plan and statistics if they are available.
• formatted: Split explain output into two sections: a physical plan outline and nodedetails.
Changed in version 3.0.0: Added optional argument mode to specify the expected outputformat of plans.
Examples
>>> df.explain()
== Physical Plan ==
*(1) Scan ExistingRDD[age#0,name#1]
>>> df.explain(True)
== Parsed Logical Plan ==
...
== Analyzed Logical Plan ==
...
== Optimized Logical Plan ==
...
== Physical Plan ==
...
>>> df.explain(mode="formatted")
== Physical Plan ==
* Scan ExistingRDD (1)
(1) Scan ExistingRDD [codegen id : 1]
Output [2]: [age#0, name#1]
...
>>> df.explain("cost")
== Optimized Logical Plan ==
...Statistics...
...
pyspark.sql.DataFrame.fillna
DataFrame.fillna(value, subset=None)
Replace null values, alias for na.fill(). DataFrame.fillna() and DataFrameNaFunctions.fill() are aliases of each other.
New in version 1.3.1.
Parameters
value [int, float, string, bool or dict] Value to replace null values with. If the value is a dict, then subset is ignored and value must be a mapping from column name (string) to replacement value. The replacement value must be an int, float, boolean, or string.
subset [str, tuple or list, optional] optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if value is a string, and subset contains a non-string column, then the non-string column is simply ignored.
Examples
>>> df4.na.fill(50).show()
+---+------+-----+
|age|height| name|
+---+------+-----+
| 10|    80|Alice|
|  5|    50|  Bob|
| 50|    50|  Tom|
| 50|    50| null|
+---+------+-----+

>>> df5.na.fill(False).show()
+----+-------+-----+
| age|   name|  spy|
+----+-------+-----+
|  10|  Alice|false|
|   5|    Bob|false|
|null|Mallory| true|
+----+-------+-----+

>>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
+---+------+-------+
|age|height|   name|
+---+------+-------+
| 10|    80|  Alice|
|  5|  null|    Bob|
| 50|  null|    Tom|
| 50|  null|unknown|
+---+------+-------+
pyspark.sql.DataFrame.filter
DataFrame.filter(condition)
Filters rows using the given condition.
where() is an alias for filter().
New in version 1.3.0.
Parameters
condition [Column or str] a Column of types.BooleanType or a string of SQL expression.
Examples
>>> df.filter(df.age > 3).collect()
[Row(age=5, name='Bob')]
>>> df.where(df.age == 2).collect()
[Row(age=2, name='Alice')]

>>> df.filter("age > 3").collect()
[Row(age=5, name='Bob')]
>>> df.where("age = 2").collect()
[Row(age=2, name='Alice')]
pyspark.sql.DataFrame.first
DataFrame.first()
Returns the first row as a Row.
New in version 1.3.0.
Examples
>>> df.first()
Row(age=2, name='Alice')
pyspark.sql.DataFrame.foreach
DataFrame.foreach(f)
Applies the f function to all Row of this DataFrame.
This is a shorthand for df.rdd.foreach().
New in version 1.3.0.
Examples
>>> def f(person):
...     print(person.name)
>>> df.foreach(f)
pyspark.sql.DataFrame.foreachPartition
DataFrame.foreachPartition(f)
Applies the f function to each partition of this DataFrame.

This is a shorthand for df.rdd.foreachPartition().
New in version 1.3.0.
Examples
>>> def f(people):
...     for person in people:
...         print(person.name)
>>> df.foreachPartition(f)
pyspark.sql.DataFrame.freqItems
DataFrame.freqItems(cols, support=None)
Finding frequent items for columns, possibly with false positives. Using the frequent element count algorithm described in https://doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou. DataFrame.freqItems() and DataFrameStatFunctions.freqItems() are aliases.
New in version 1.4.0.
Parameters
cols [list or tuple] Names of the columns to calculate frequent items for as a list or tuple of strings.
support [float, optional] The frequency with which to consider an item 'frequent'. Default is 1%. The support must be greater than 1e-4.
Notes
This function is meant for exploratory data analysis, as we make no guarantee about the backward compatibility of the schema of the resulting DataFrame.
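For illustration, a minimal sketch using the doctest df from the surrounding examples; the result is a single-row DataFrame with one array column per input column:

>>> fi = df.freqItems(["age", "name"], support=0.5)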
pyspark.sql.DataFrame.groupBy
DataFrame.groupBy(*cols)
Groups the DataFrame using the specified columns, so we can run aggregation on them. See GroupedData for all the available aggregate functions.
groupby() is an alias for groupBy().
New in version 1.3.0.
Parameters
cols [list, str or Column] columns to group by. Each element should be a column name (string) or an expression (Column).
Examples
>>> df.groupBy().avg().collect()
[Row(avg(age)=3.5)]
>>> sorted(df.groupBy('name').agg({'age': 'mean'}).collect())
[Row(name='Alice', avg(age)=2.0), Row(name='Bob', avg(age)=5.0)]
>>> sorted(df.groupBy(df.name).avg().collect())
[Row(name='Alice', avg(age)=2.0), Row(name='Bob', avg(age)=5.0)]
>>> sorted(df.groupBy(['name', df.age]).count().collect())
[Row(name='Alice', age=2, count=1), Row(name='Bob', age=5, count=1)]
pyspark.sql.DataFrame.head
DataFrame.head(n=None)
Returns the first n rows.
New in version 1.3.0.
Parameters
n [int, optional] default 1. Number of rows to return.
Returns
If n is greater than 1, return a list of Row .
If n is 1, return a single Row.
Notes
This method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory.
Examples
>>> df.head()
Row(age=2, name='Alice')
>>> df.head(1)
[Row(age=2, name='Alice')]
pyspark.sql.DataFrame.hint
DataFrame.hint(name, *parameters)
Specifies some hint on the current DataFrame.
New in version 2.2.0.
Parameters
name [str] A name of the hint.
parameters [str, list, float or int] Optional parameters.
Returns
DataFrame
Examples
>>> df.join(df2.hint("broadcast"), "name").show()
+----+---+------+
|name|age|height|
+----+---+------+
| Bob|  5|    85|
+----+---+------+
pyspark.sql.DataFrame.inputFiles
DataFrame.inputFiles()
Returns a best-effort snapshot of the files that compose this DataFrame. This method simply asks each constituent BaseRelation for its respective files and takes the union of all results. Depending on the source relations, this may not find all input files. Duplicates are removed.
New in version 3.1.0.
Examples
>>> df = spark.read.load("examples/src/main/resources/people.json", format="json")
>>> len(df.inputFiles())
1
pyspark.sql.DataFrame.intersect
DataFrame.intersect(other)
Return a new DataFrame containing rows only in both this DataFrame and another DataFrame.
This is equivalent to INTERSECT in SQL.
New in version 1.3.
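For illustration, a minimal sketch on hypothetical data; note that intersect also deduplicates the result:

>>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 2)], ["C1", "C2"])
>>> df2 = spark.createDataFrame([("a", 1), ("c", 3)], ["C1", "C2"])
>>> df1.intersect(df2).collect()
[Row(C1='a', C2=1)]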
pyspark.sql.DataFrame.intersectAll
DataFrame.intersectAll(other)
Return a new DataFrame containing rows in both this DataFrame and another DataFrame while preserving duplicates.
This is equivalent to INTERSECT ALL in SQL. As standard in SQL, this function resolves columns by position(not by name).
New in version 2.4.0.
Examples
>>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"])
>>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"])

>>> df1.intersectAll(df2).sort("C1", "C2").show()
+---+---+
| C1| C2|
+---+---+
|  a|  1|
|  a|  1|
|  b|  3|
+---+---+
pyspark.sql.DataFrame.isLocal
DataFrame.isLocal()
Returns True if the collect() and take() methods can be run locally (without any Spark executors).
New in version 1.3.
pyspark.sql.DataFrame.isStreaming
property DataFrame.isStreaming
Returns True if this Dataset contains one or more sources that continuously return data as it arrives. A Dataset that reads data from a streaming source must be executed as a StreamingQuery using the start() method in DataStreamWriter. Methods that return a single answer (e.g., count() or collect()) will throw an AnalysisException when there is a streaming source present.
New in version 2.0.0.
Notes
This API is evolving.
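For illustration, a batch DataFrame reports False:

>>> df.isStreaming
False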
pyspark.sql.DataFrame.join
DataFrame.join(other, on=None, how=None)
Joins with another DataFrame, using the given join expression.
New in version 1.3.0.
Parameters
other [DataFrame] Right side of the join
on [str, list or Column, optional] a string for the join column name, a list of column names, a join expression (Column), or a list of Columns. If on is a string or a list of strings indicating the name of the join column(s), the column(s) must exist on both sides, and this performs an equi-join.
how [str, optional] default inner. Must be one of: inner, cross, outer, full, fullouter, full_outer, left, leftouter, left_outer, right, rightouter, right_outer, semi, leftsemi, left_semi, anti, leftanti and left_anti.
Examples
The following performs a full outer join between df and df2.

>>> from pyspark.sql.functions import desc
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height) \
...     .sort(desc("name")).collect()
[Row(name='Bob', height=85), Row(name='Alice', height=None), Row(name=None, height=80)]
>>> df.join(df2, 'name', 'outer').select('name', 'height').sort(desc("name")).collect()
[Row(name='Tom', height=80), Row(name='Bob', height=85), Row(name='Alice', height=None)]
>>> cond = [df.name == df3.name, df.age == df3.age]
>>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect()
[Row(name='Alice', age=2), Row(name='Bob', age=5)]

>>> df.join(df2, 'name').select(df.name, df2.height).collect()
[Row(name='Bob', height=85)]

>>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect()
[Row(name='Bob', age=5)]
pyspark.sql.DataFrame.limit
DataFrame.limit(num)
Limits the result count to the number specified.
New in version 1.3.0.
Examples
>>> df.limit(1).collect()
[Row(age=2, name='Alice')]
>>> df.limit(0).collect()
[]
pyspark.sql.DataFrame.localCheckpoint
DataFrame.localCheckpoint(eager=True)
Returns a locally checkpointed version of this Dataset. Checkpointing can be used to truncate the logical plan of this DataFrame, which is especially useful in iterative algorithms where the plan may grow exponentially. Local checkpoints are stored in the executors using the caching subsystem and therefore they are not reliable.
New in version 2.3.0.
Parameters
eager [bool, optional] Whether to checkpoint this DataFrame immediately
Notes
This API is experimental.
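A minimal usage sketch; like checkpoint(), it truncates the lineage, but without requiring a checkpoint directory:

>>> truncated = df.localCheckpoint()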
pyspark.sql.DataFrame.mapInPandas
DataFrame.mapInPandas(func, schema)
Maps an iterator of batches in the current DataFrame using a Python native function that takes and outputs a pandas DataFrame, and returns the result as a DataFrame.

The function should take an iterator of pandas.DataFrames and return another iterator of pandas.DataFrames. All columns are passed together as an iterator of pandas.DataFrames to the function and the returned iterator of pandas.DataFrames are combined as a DataFrame. Each pandas.DataFrame size can be controlled by spark.sql.execution.arrow.maxRecordsPerBatch.
New in version 3.0.0.
Parameters
func [function] a Python native function that takes an iterator of pandas.DataFrames, and outputs an iterator of pandas.DataFrames.
schema [pyspark.sql.types.DataType or str] the return type of the func in PySpark. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.
See also:
pyspark.sql.functions.pandas_udf
Notes
This API is experimental
Examples
>>> from pyspark.sql.functions import pandas_udf
>>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
>>> def filter_func(iterator):
...     for pdf in iterator:
...         yield pdf[pdf.id == 1]
>>> df.mapInPandas(filter_func, df.schema).show()
+---+---+
| id|age|
+---+---+
|  1| 21|
+---+---+
pyspark.sql.DataFrame.na
property DataFrame.na
Returns a DataFrameNaFunctions for handling missing values.
New in version 1.3.1.
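For illustration (an added sketch; assumes the df4 DataFrame used in the na.drop() examples below, in which only one row contains no nulls):

>>> df4.na.drop().count()
1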
pyspark.sql.DataFrame.orderBy
DataFrame.orderBy(*cols, **kwargs)
Returns a new DataFrame sorted by the specified column(s).
New in version 1.3.0.
Parameters
cols [str, list, or Column, optional] list of Column or column names to sort by.
Other Parameters
ascending [bool or list, optional] boolean or list of boolean (default True). Sort ascending vs. descending. Specify list for multiple sort orders. If a list is specified, the length of the list must equal the length of the cols.
Examples
>>> df.sort(df.age.desc()).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> df.sort("age", ascending=False).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> df.orderBy(df.age.desc()).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> from pyspark.sql.functions import *
>>> df.sort(asc("age")).collect()
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]
>>> df.orderBy(desc("age"), "name").collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
pyspark.sql.DataFrame.persist
DataFrame.persist(storageLevel=StorageLevel(True, True, False, True, 1))
Sets the storage level to persist the contents of the DataFrame across operations after the first time it is computed. This can only be used to assign a new storage level if the DataFrame does not have a storage level set yet. If no storage level is specified, it defaults to MEMORY_AND_DISK_DESER.
New in version 1.3.0.
Notes
The default storage level has changed to MEMORY_AND_DISK_DESER to match Scala in 3.0.
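A minimal illustrative sketch (an added example; persist() returns the DataFrame itself, so the call can be chained):

>>> df2 = df.persist()
>>> df2.is_cached
True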
pyspark.sql.DataFrame.printSchema
DataFrame.printSchema()
Prints out the schema in the tree format.
New in version 1.3.0.
Examples
>>> df.printSchema()
root
 |-- age: integer (nullable = true)
 |-- name: string (nullable = true)
pyspark.sql.DataFrame.randomSplit
DataFrame.randomSplit(weights, seed=None)
Randomly splits this DataFrame with the provided weights.
New in version 1.4.0.
Parameters
weights [list] list of doubles as weights with which to split the DataFrame. Weights will be normalized if they don't sum up to 1.0.
seed [int, optional] The seed for sampling.
Examples
>>> splits = df4.randomSplit([1.0, 2.0], 24)
>>> splits[0].count()
2
>>> splits[1].count()
2
pyspark.sql.DataFrame.rdd
property DataFrame.rdd
Returns the content as a pyspark.RDD of Row.
New in version 1.3.
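For illustration (an added sketch; assumes the two-row df used throughout these examples):

>>> df.rdd.map(lambda row: row.name).collect()
['Alice', 'Bob']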
pyspark.sql.DataFrame.registerTempTable
DataFrame.registerTempTable(name)
Registers this DataFrame as a temporary table using the given name.
The lifetime of this temporary table is tied to the SparkSession that was used to create this DataFrame.
New in version 1.3.0.
Deprecated since version 2.0.0: Use DataFrame.createOrReplaceTempView() instead.
Examples
>>> df.registerTempTable("people")
>>> df2 = spark.sql("select * from people")
>>> sorted(df.collect()) == sorted(df2.collect())
True
>>> spark.catalog.dropTempView("people")
pyspark.sql.DataFrame.repartition
DataFrame.repartition(numPartitions, *cols)
Returns a new DataFrame partitioned by the given partitioning expressions. The resulting DataFrame is hash partitioned.
New in version 1.3.0.
Parameters
numPartitions [int] can be an int to specify the target number of partitions or a Column. If it is a Column, it will be used as the first partitioning column. If not specified, the default number of partitions is used.
cols [str or Column] partitioning columns.
Changed in version 1.6: Added optional arguments to specify the partitioning columns. Also made numPartitions optional if partitioning columns are specified.
Examples
>>> df.repartition(10).rdd.getNumPartitions()
10
>>> data = df.union(df).repartition("age")
>>> data.show()
+---+-----+
|age| name|
+---+-----+
|  5|  Bob|
|  5|  Bob|
|  2|Alice|
|  2|Alice|
+---+-----+
>>> data = data.repartition(7, "age")
>>> data.show()
+---+-----+
|age| name|
+---+-----+
|  2|Alice|
|  5|  Bob|
|  2|Alice|
|  5|  Bob|
+---+-----+
>>> data.rdd.getNumPartitions()
7
>>> data = data.repartition("name", "age")
>>> data.show()
+---+-----+
|age| name|
+---+-----+
|  5|  Bob|
|  5|  Bob|
|  2|Alice|
|  2|Alice|
+---+-----+
pyspark.sql.DataFrame.repartitionByRange
DataFrame.repartitionByRange(numPartitions, *cols)
Returns a new DataFrame partitioned by the given partitioning expressions. The resulting DataFrame is range partitioned.
At least one partition-by expression must be specified. When no explicit sort order is specified, "ascending nulls first" is assumed.
New in version 2.4.0.
Parameters
numPartitions [int] can be an int to specify the target number of partitions or a Column. If it is a Column, it will be used as the first partitioning column. If not specified, the default number of partitions is used.
cols [str or Column] partitioning columns.
Notes
Due to performance reasons this method uses sampling to estimate the ranges. Hence, the output may not be consistent, since sampling can return different values. The sample size can be controlled by the config spark.sql.execution.rangeExchange.sampleSizePerPartition.
Examples
>>> df.repartitionByRange(2, "age").rdd.getNumPartitions()
2
>>> df.show()
+---+-----+
|age| name|
+---+-----+
|  2|Alice|
|  5|  Bob|
+---+-----+
>>> df.repartitionByRange(1, "age").rdd.getNumPartitions()
1
>>> data = df.repartitionByRange("age")
>>> df.show()
+---+-----+
|age| name|
+---+-----+
|  2|Alice|
|  5|  Bob|
+---+-----+
pyspark.sql.DataFrame.replace
DataFrame.replace(to_replace, value=<no value>, subset=None)
Returns a new DataFrame replacing a value with another value. DataFrame.replace() and DataFrameNaFunctions.replace() are aliases of each other. Values to_replace and value must have the same type and can only be numerics, booleans, or strings. Value can have None. When replacing, the new value will be cast to the type of the existing column. For numeric replacements all values to be replaced should have unique floating point representation. In case of conflicts (for example with {42: -1, 42.0: 1}) an arbitrary replacement will be used.
New in version 1.4.0.
Parameters
to_replace [bool, int, float, string, list or dict] Value to be replaced. If the value is a dict, then value is ignored or can be omitted, and to_replace must be a mapping between a value and a replacement.

value [bool, int, float, string or None, optional] The replacement value must be a bool, int, float, string or None. If value is a list, value should be of the same length and type as to_replace. If value is a scalar and to_replace is a sequence, then value is used as a replacement for each item in to_replace.

subset [list, optional] optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if value is a string, and subset contains a non-string column, then the non-string column is simply ignored.
Examples
>>> df4.na.replace(10, 20).show()
+----+------+-----+
| age|height| name|
+----+------+-----+
|  20|    80|Alice|
|   5|  null|  Bob|
|null|  null|  Tom|
|null|  null| null|
+----+------+-----+

>>> df4.na.replace('Alice', None).show()
+----+------+----+
| age|height|name|
+----+------+----+
|  10|    80|null|
|   5|  null| Bob|
|null|  null| Tom|
|null|  null|null|
+----+------+----+

>>> df4.na.replace({'Alice': None}).show()
+----+------+----+
| age|height|name|
+----+------+----+
|  10|    80|null|
|   5|  null| Bob|
|null|  null| Tom|
|null|  null|null|
+----+------+----+

>>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
+----+------+----+
| age|height|name|
+----+------+----+
|  10|    80|   A|
|   5|  null|   B|
|null|  null| Tom|
|null|  null|null|
+----+------+----+
pyspark.sql.DataFrame.rollup
DataFrame.rollup(*cols)
Create a multi-dimensional rollup for the current DataFrame using the specified columns, so we can run aggregation on them.
New in version 1.4.0.
Examples
>>> df.rollup("name", df.age).count().orderBy("name", "age").show()
+-----+----+-----+
| name| age|count|
+-----+----+-----+
| null|null|    2|
|Alice|null|    1|
|Alice|   2|    1|
|  Bob|null|    1|
|  Bob|   5|    1|
+-----+----+-----+
pyspark.sql.DataFrame.sameSemantics
DataFrame.sameSemantics(other)
Returns True when the logical query plans inside both DataFrames are equal and therefore return the same results.
New in version 3.1.0.
Notes
The equality comparison here is simplified by tolerating the cosmetic differences such as attribute names.
This API can compare both DataFrames very fast but can still return False on DataFrames that return the same results, for instance, from different plans. Such false negatives are acceptable, for example, when this is used for caching.
This API is a developer API.
Examples
>>> df1 = spark.range(10)
>>> df2 = spark.range(10)
>>> df1.withColumn("col1", df1.id * 2).sameSemantics(df2.withColumn("col1", df2.id * 2))
True
>>> df1.withColumn("col1", df1.id * 2).sameSemantics(df2.withColumn("col1", df2.id + 2))
False
>>> df1.withColumn("col1", df1.id * 2).sameSemantics(df2.withColumn("col0", df2.id * 2))
True
pyspark.sql.DataFrame.sample
DataFrame.sample(withReplacement=None, fraction=None, seed=None)
Returns a sampled subset of this DataFrame.
New in version 1.3.0.
Parameters
withReplacement [bool, optional] Sample with replacement or not (default False).
fraction [float, optional] Fraction of rows to generate, range [0.0, 1.0].
seed [int, optional] Seed for sampling (default a random seed).
Notes
This is not guaranteed to provide exactly the fraction specified of the total count of the given DataFrame.
fraction is required; withReplacement and seed are optional.
Examples
>>> df = spark.range(10)
>>> df.sample(0.5, 3).count()
7
>>> df.sample(fraction=0.5, seed=3).count()
7
>>> df.sample(withReplacement=True, fraction=0.5, seed=3).count()
1
>>> df.sample(1.0).count()
10
>>> df.sample(fraction=1.0).count()
10
>>> df.sample(False, fraction=1.0).count()
10
pyspark.sql.DataFrame.sampleBy
DataFrame.sampleBy(col, fractions, seed=None)
Returns a stratified sample without replacement based on the fraction given on each stratum.
New in version 1.5.0.
Parameters
col [Column or str] column that defines strata
Changed in version 3.0: Added sampling by a column of Column
fractions [dict] sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as zero.
seed [int, optional] random seed
Returns
a new DataFrame that represents the stratified sample
Examples
>>> from pyspark.sql.functions import col
>>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
>>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
>>> sampled.groupBy("key").count().orderBy("key").show()
+---+-----+
|key|count|
+---+-----+
|  0|    3|
|  1|    6|
+---+-----+
>>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
33
pyspark.sql.DataFrame.schema
property DataFrame.schema
Returns the schema of this DataFrame as a pyspark.sql.types.StructType.
New in version 1.3.0.
Examples
>>> df.schema
StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
pyspark.sql.DataFrame.select
DataFrame.select(*cols)
Projects a set of expressions and returns a new DataFrame.
New in version 1.3.0.
Parameters
cols [str, Column, or list] column names (string) or expressions (Column). If one of the column names is '*', that column is expanded to include all columns in the current DataFrame.
Examples
>>> df.select('*').collect()
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]
>>> df.select('name', 'age').collect()
[Row(name='Alice', age=2), Row(name='Bob', age=5)]
>>> df.select(df.name, (df.age + 10).alias('age')).collect()
[Row(name='Alice', age=12), Row(name='Bob', age=15)]
pyspark.sql.DataFrame.selectExpr
DataFrame.selectExpr(*expr)
Projects a set of SQL expressions and returns a new DataFrame.
This is a variant of select() that accepts SQL expressions.
New in version 1.3.0.
Examples
>>> df.selectExpr("age * 2", "abs(age)").collect()
[Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)]
pyspark.sql.DataFrame.semanticHash
DataFrame.semanticHash()
Returns a hash code of the logical query plan against this DataFrame.
New in version 3.1.0.
Notes
Unlike the standard hash code, the hash is calculated against the query plan simplified by tolerating the cosmetic differences such as attribute names.
This API is a developer API.
Examples
>>> spark.range(10).selectExpr("id as col0").semanticHash()
1855039936
>>> spark.range(10).selectExpr("id as col1").semanticHash()
1855039936
pyspark.sql.DataFrame.show
DataFrame.show(n=20, truncate=True, vertical=False)
Prints the first n rows to the console.
New in version 1.3.0.
Parameters
n [int, optional] Number of rows to show.
truncate [bool, optional] If set to True, truncate strings longer than 20 chars by default. If set to a number greater than one, truncates long strings to length truncate and aligns cells right.
vertical [bool, optional] If set to True, print output rows vertically (one line per column value).
Examples
>>> df
DataFrame[age: int, name: string]
>>> df.show()
+---+-----+
|age| name|
+---+-----+
|  2|Alice|
|  5|  Bob|
+---+-----+
>>> df.show(truncate=3)
+---+----+
|age|name|
+---+----+
|  2| Ali|
|  5| Bob|
+---+----+
>>> df.show(vertical=True)
-RECORD 0-----
 age  | 2
 name | Alice
-RECORD 1-----
 age  | 5
 name | Bob
pyspark.sql.DataFrame.sort
DataFrame.sort(*cols, **kwargs)
Returns a new DataFrame sorted by the specified column(s).
New in version 1.3.0.
Parameters
cols [str, list, or Column, optional] list of Column or column names to sort by.
Other Parameters
ascending [bool or list, optional] boolean or list of boolean (default True). Sort ascending vs. descending. Specify list for multiple sort orders. If a list is specified, the length of the list must equal the length of the cols.
Examples
>>> df.sort(df.age.desc()).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> df.sort("age", ascending=False).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> df.orderBy(df.age.desc()).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> from pyspark.sql.functions import *
>>> df.sort(asc("age")).collect()
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]
>>> df.orderBy(desc("age"), "name").collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
pyspark.sql.DataFrame.sortWithinPartitions
DataFrame.sortWithinPartitions(*cols, **kwargs)
Returns a new DataFrame with each partition sorted by the specified column(s).
New in version 1.6.0.
Parameters
cols [str, list or Column, optional] list of Column or column names to sort by.
Other Parameters
ascending [bool or list, optional] boolean or list of boolean (default True). Sort ascending vs. descending. Specify list for multiple sort orders. If a list is specified, the length of the list must equal the length of the cols.
Examples
>>> df.sortWithinPartitions("age", ascending=False).show()
+---+-----+
|age| name|
+---+-----+
|  2|Alice|
|  5|  Bob|
+---+-----+
pyspark.sql.DataFrame.stat
property DataFrame.stat
Returns a DataFrameStatFunctions for statistic functions.
New in version 1.4.
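For illustration (an added sketch; the covariance of a column with itself is its sample variance, here computed over the two-row df):

>>> df.stat.cov("age", "age")
4.5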
pyspark.sql.DataFrame.storageLevel
property DataFrame.storageLevel
Get the DataFrame's current storage level.
New in version 2.1.0.
Examples
>>> df.storageLevel
StorageLevel(False, False, False, False, 1)
>>> df.cache().storageLevel
StorageLevel(True, True, False, True, 1)
>>> df2.persist(StorageLevel.DISK_ONLY_2).storageLevel
StorageLevel(True, False, False, False, 2)
pyspark.sql.DataFrame.subtract
DataFrame.subtract(other)
Return a new DataFrame containing rows in this DataFrame but not in another DataFrame.
This is equivalent to EXCEPT DISTINCT in SQL.
New in version 1.3.
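A short illustrative sketch (df_a and df_b are hypothetical frames, not part of the original docstring):

>>> df_a = spark.createDataFrame([(1,), (2,), (3,)], ["v"])
>>> df_b = spark.createDataFrame([(3,), (4,)], ["v"])
>>> sorted(df_a.subtract(df_b).collect())
[Row(v=1), Row(v=2)]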
pyspark.sql.DataFrame.summary
DataFrame.summary(*statistics)
Computes specified statistics for numeric and string columns. Available statistics are: count, mean, stddev, min, max, and arbitrary approximate percentiles specified as a percentage (e.g., 75%).
If no statistics are given, this function computes count, mean, stddev, min, approximate quartiles (percentiles at 25%, 50%, and 75%), and max.
New in version 2.3.0.
See also:
DataFrame.describe
Notes
This function is meant for exploratory data analysis, as we make no guarantee about the backward compatibility of the schema of the resulting DataFrame.
Examples
>>> df.summary().show()
+-------+------------------+-----+
|summary|               age| name|
+-------+------------------+-----+
|  count|                 2|    2|
|   mean|               3.5| null|
| stddev|2.1213203435596424| null|
|    min|                 2|Alice|
|    25%|                 2| null|
|    50%|                 2| null|
|    75%|                 5| null|
|    max|                 5|  Bob|
+-------+------------------+-----+
>>> df.summary("count", "min", "25%", "75%", "max").show()
+-------+---+-----+
|summary|age| name|
+-------+---+-----+
|  count|  2|    2|
|    min|  2|Alice|
|    25%|  2| null|
|    75%|  5| null|
|    max|  5|  Bob|
+-------+---+-----+
To do a summary for specific columns first select them:
>>> df.select("age", "name").summary("count").show()
+-------+---+----+
|summary|age|name|
+-------+---+----+
|  count|  2|   2|
+-------+---+----+
pyspark.sql.DataFrame.tail
DataFrame.tail(num)
Returns the last num rows as a list of Row.
Running tail requires moving data into the application's driver process, and doing so with a very large num can crash the driver process with OutOfMemoryError.
New in version 3.0.0.
Examples
>>> df.tail(1)
[Row(age=5, name='Bob')]
pyspark.sql.DataFrame.take
DataFrame.take(num)
Returns the first num rows as a list of Row.
New in version 1.3.0.
Examples
>>> df.take(2)
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]
pyspark.sql.DataFrame.toDF
DataFrame.toDF(*cols)
Returns a new DataFrame with the new specified column names.
Parameters
cols [str] new column names
Examples
>>> df.toDF('f1', 'f2').collect()
[Row(f1=2, f2='Alice'), Row(f1=5, f2='Bob')]
pyspark.sql.DataFrame.toJSON
DataFrame.toJSON(use_unicode=True)
Converts a DataFrame into an RDD of strings.
Each row is turned into a JSON document as one element in the returned RDD.
New in version 1.3.0.
Examples
>>> df.toJSON().first()
'{"age":2,"name":"Alice"}'
pyspark.sql.DataFrame.toLocalIterator
DataFrame.toLocalIterator(prefetchPartitions=False)
Returns an iterator that contains all of the rows in this DataFrame. The iterator will consume as much memory as the largest partition in this DataFrame. With prefetch it may consume up to the memory of the 2 largest partitions.
New in version 2.0.0.
Parameters
prefetchPartitions [bool, optional] If Spark should pre-fetch the next partition before it is needed.
Examples
>>> list(df.toLocalIterator())
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]
pyspark.sql.DataFrame.toPandas
DataFrame.toPandas()
Returns the contents of this DataFrame as a Pandas pandas.DataFrame.
This is only available if Pandas is installed and available.
New in version 1.3.0.
Notes
This method should only be used if the resulting Pandas DataFrame is expected to be small, as all the data is loaded into the driver's memory.
Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental.
Examples
>>> df.toPandas()
   age   name
0    2  Alice
1    5    Bob
pyspark.sql.DataFrame.transform
DataFrame.transform(func)
Returns a new DataFrame. Concise syntax for chaining custom transformations.
New in version 3.0.0.
Parameters
func [function] a function that takes and returns a DataFrame.
Examples
>>> from pyspark.sql.functions import col
>>> df = spark.createDataFrame([(1, 1.0), (2, 2.0)], ["int", "float"])
>>> def cast_all_to_int(input_df):
...     return input_df.select([col(col_name).cast("int") for col_name in input_df.columns])
>>> def sort_columns_asc(input_df):
...     return input_df.select(*sorted(input_df.columns))
>>> df.transform(cast_all_to_int).transform(sort_columns_asc).show()
+-----+---+
|float|int|
+-----+---+
|    1|  1|
|    2|  2|
+-----+---+
pyspark.sql.DataFrame.union
DataFrame.union(other)
Return a new DataFrame containing union of rows in this and another DataFrame.
This is equivalent to UNION ALL in SQL. To do a SQL-style set union (that does deduplication of elements),use this function followed by distinct().
Also as standard in SQL, this function resolves columns by position (not by name).
New in version 2.0.
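A short illustrative sketch (df_a is hypothetical; note that duplicates are kept unless distinct() is applied):

>>> df_a = spark.createDataFrame([(1, 'x')], ["id", "val"])
>>> df_a.union(df_a).collect()
[Row(id=1, val='x'), Row(id=1, val='x')]
>>> df_a.union(df_a).distinct().collect()
[Row(id=1, val='x')]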
pyspark.sql.DataFrame.unionAll
DataFrame.unionAll(other)
Return a new DataFrame containing union of rows in this and another DataFrame.
This is equivalent to UNION ALL in SQL. To do a SQL-style set union (that does deduplication of elements),use this function followed by distinct().
Also as standard in SQL, this function resolves columns by position (not by name).
New in version 1.3.
pyspark.sql.DataFrame.unionByName
DataFrame.unionByName(other, allowMissingColumns=False)
Returns a new DataFrame containing union of rows in this and another DataFrame.
This is different from both UNION ALL and UNION DISTINCT in SQL. To do a SQL-style set union (that does deduplication of elements), use this function followed by distinct().
New in version 2.3.0.
Examples
The difference between this function and union() is that this function resolves columns by name (not by position):
>>> df1 = spark.createDataFrame([[1, 2, 3]], ["col0", "col1", "col2"])
>>> df2 = spark.createDataFrame([[4, 5, 6]], ["col1", "col2", "col0"])
>>> df1.unionByName(df2).show()
+----+----+----+
|col0|col1|col2|
+----+----+----+
|   1|   2|   3|
|   6|   4|   5|
+----+----+----+
When the parameter allowMissingColumns is True, the set of column names in this and other DataFrame can differ; missing columns will be filled with null. Further, the missing columns of this DataFrame will be added at the end in the schema of the union result:
>>> df1 = spark.createDataFrame([[1, 2, 3]], ["col0", "col1", "col2"])
>>> df2 = spark.createDataFrame([[4, 5, 6]], ["col1", "col2", "col3"])
>>> df1.unionByName(df2, allowMissingColumns=True).show()
+----+----+----+----+
|col0|col1|col2|col3|
+----+----+----+----+
|   1|   2|   3|null|
|null|   4|   5|   6|
+----+----+----+----+
Changed in version 3.1.0: Added optional argument allowMissingColumns to specify whether to allow missingcolumns.
pyspark.sql.DataFrame.unpersist
DataFrame.unpersist(blocking=False)
Marks the DataFrame as non-persistent, and removes all blocks for it from memory and disk.
New in version 1.3.0.
Notes
blocking default has changed to False to match Scala in 2.0.
pyspark.sql.DataFrame.where
DataFrame.where(condition)
where() is an alias for filter().
New in version 1.3.
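For illustration (an added sketch; assumes the two-row df used throughout these examples):

>>> df.where(df.age > 3).collect()
[Row(age=5, name='Bob')]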
pyspark.sql.DataFrame.withColumn
DataFrame.withColumn(colName, col)
Returns a new DataFrame by adding a column or replacing the existing column that has the same name.
The column expression must be an expression over this DataFrame; attempting to add a column from some other DataFrame will raise an error.
New in version 1.3.0.
Parameters
colName [str] string, name of the new column.
col [Column] a Column expression for the new column.
Notes
This method introduces a projection internally. Therefore, calling it multiple times, for instance, via loops in order to add multiple columns can generate big plans which can cause performance issues and even StackOverflowException. To avoid this, use select() with the multiple columns at once.
Examples
>>> df.withColumn('age2', df.age + 2).collect()
[Row(age=2, name='Alice', age2=4), Row(age=5, name='Bob', age2=7)]
pyspark.sql.DataFrame.withColumnRenamed
DataFrame.withColumnRenamed(existing, new)
Returns a new DataFrame by renaming an existing column. This is a no-op if schema doesn't contain the given column name.
New in version 1.3.0.
Parameters
existing [str] string, name of the existing column to rename.
new [str] string, new name of the column.
Examples
>>> df.withColumnRenamed('age', 'age2').collect()
[Row(age2=2, name='Alice'), Row(age2=5, name='Bob')]
pyspark.sql.DataFrame.withWatermark
DataFrame.withWatermark(eventTime, delayThreshold)
Defines an event time watermark for this DataFrame. A watermark tracks a point in time before which we assume no more late data is going to arrive.
Spark will use this watermark for several purposes:
• To know when a given time window aggregation can be finalized and thus can be emitted when using output modes that do not allow updates.
• To minimize the amount of state that we need to keep for on-going aggregations.
The current watermark is computed by looking at the MAX(eventTime) seen across all of the partitions in the query minus a user specified delayThreshold. Due to the cost of coordinating this value across partitions, the actual watermark used is only guaranteed to be at least delayThreshold behind the actual event time. In some cases we may still process records that arrive more than delayThreshold late.
New in version 2.1.0.
Parameters
eventTime [str or Column] the name of the column that contains the event time of the row.
delayThreshold [str] the minimum delay to wait for data to arrive late, relative to the latest record that has been processed, in the form of an interval (e.g. "1 minute" or "5 hours").
Notes
This API is evolving.
>>> from pyspark.sql.functions import timestamp_seconds
>>> sdf.select(
...     'name',
...     timestamp_seconds(sdf.time).alias('time')).withWatermark('time', '10 minutes')
DataFrame[name: string, time: timestamp]
pyspark.sql.DataFrame.write
property DataFrame.write
Interface for saving the content of the non-streaming DataFrame out into external storage.
New in version 1.4.0.
Returns
DataFrameWriter
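A minimal sketch (an illustrative addition; the output path is hypothetical):

>>> df.write.mode("overwrite").parquet("/tmp/people.parquet")  # doctest: +SKIP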
pyspark.sql.DataFrame.writeStream
property DataFrame.writeStream
Interface for saving the content of the streaming DataFrame out into external storage.
New in version 2.0.0.
Returns
DataStreamWriter
Notes
This API is evolving.
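A minimal sketch (an illustrative addition; assumes sdf is a streaming DataFrame, as in the withWatermark() example):

>>> query = sdf.writeStream.format("console").start()  # doctest: +SKIP
>>> query.stop()  # doctest: +SKIP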
pyspark.sql.DataFrame.writeTo
DataFrame.writeTo(table)
Create a write configuration builder for v2 sources.
This builder is used to configure and execute write operations.
For example, to append or create or replace existing tables.
New in version 3.1.0.
Examples
>>> df.writeTo("catalog.db.table").append()
>>> df.writeTo(
...     "catalog.db.table"
... ).partitionedBy("col").createOrReplace()
pyspark.sql.DataFrameNaFunctions.drop
DataFrameNaFunctions.drop(how='any', thresh=None, subset=None)
Returns a new DataFrame omitting rows with null values. DataFrame.dropna() and DataFrameNaFunctions.drop() are aliases of each other.
New in version 1.3.1.
Parameters
how [str, optional] 'any' or 'all'. If 'any', drop a row if it contains any nulls. If 'all', drop a row only if all its values are null.
thresh: int, optional, default None If specified, drop rows that have less than thresh non-null values. This overwrites the how parameter.
subset [str, tuple or list, optional] optional list of column names to consider.
Examples
>>> df4.na.drop().show()
+---+------+-----+
|age|height| name|
+---+------+-----+
| 10|    80|Alice|
+---+------+-----+
pyspark.sql.DataFrameNaFunctions.fill
DataFrameNaFunctions.fill(value, subset=None)
Replace null values, alias for na.fill(). DataFrame.fillna() and DataFrameNaFunctions.fill() are aliases of each other.
New in version 1.3.1.
Parameters
value [int, float, string, bool or dict] Value to replace null values with. If the value is a dict, then subset is ignored and value must be a mapping from column name (string) to replacement value. The replacement value must be an int, float, boolean, or string.
subset [str, tuple or list, optional] optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if value is a string, and subset contains a non-string column, then the non-string column is simply ignored.
Examples
>>> df4.na.fill(50).show()
+---+------+-----+
|age|height| name|
+---+------+-----+
| 10|    80|Alice|
|  5|    50|  Bob|
| 50|    50|  Tom|
| 50|    50| null|
+---+------+-----+

>>> df5.na.fill(False).show()
+----+-------+-----+
| age|   name|  spy|
+----+-------+-----+
|  10|  Alice|false|
|   5|    Bob|false|
|null|Mallory| true|
+----+-------+-----+

>>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
+---+------+-------+
|age|height|   name|
+---+------+-------+
| 10|    80|  Alice|
|  5|  null|    Bob|
| 50|  null|    Tom|
| 50|  null|unknown|
+---+------+-------+
pyspark.sql.DataFrameNaFunctions.replace
DataFrameNaFunctions.replace(to_replace, value=<no value>, subset=None)
Returns a new DataFrame replacing a value with another value. DataFrame.replace() and DataFrameNaFunctions.replace() are aliases of each other. Values to_replace and value must have the same type and can only be numerics, booleans, or strings. Value can have None. When replacing, the new value will be cast to the type of the existing column. For numeric replacements all values to be replaced should have unique floating point representation. In case of conflicts (for example with {42: -1, 42.0: 1}) an arbitrary replacement will be used.
New in version 1.4.0.
Parameters
to_replace [bool, int, float, string, list or dict] Value to be replaced. If the value is a dict, then value is ignored or can be omitted, and to_replace must be a mapping between a value and a replacement.

value [bool, int, float, string or None, optional] The replacement value must be a bool, int, float, string or None. If value is a list, value should be of the same length and type as to_replace. If value is a scalar and to_replace is a sequence, then value is used as a replacement for each item in to_replace.

subset [list, optional] optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if value is a string, and subset contains a non-string column, then the non-string column is simply ignored.
Examples
>>> df4.na.replace(10, 20).show()
+----+------+-----+
| age|height| name|
+----+------+-----+
|  20|    80|Alice|
|   5|  null|  Bob|
|null|  null|  Tom|
|null|  null| null|
+----+------+-----+

>>> df4.na.replace('Alice', None).show()
+----+------+----+
| age|height|name|
+----+------+----+
|  10|    80|null|
|   5|  null| Bob|
|null|  null| Tom|
|null|  null|null|
+----+------+----+

>>> df4.na.replace({'Alice': None}).show()
+----+------+----+
| age|height|name|
+----+------+----+
|  10|    80|null|
|   5|  null| Bob|
|null|  null| Tom|
|null|  null|null|
+----+------+----+

>>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
+----+------+----+
| age|height|name|
+----+------+----+
|  10|    80|   A|
|   5|  null|   B|
|null|  null| Tom|
|null|  null|null|
+----+------+----+
pyspark.sql.DataFrameStatFunctions.approxQuantile
DataFrameStatFunctions.approxQuantile(col, probabilities, relativeError)
Calculates the approximate quantiles of numerical columns of a DataFrame.
The result of this algorithm has the following deterministic bound: If the DataFrame has N elements and if we request the quantile at probability p up to error err, then the algorithm will return a sample x from the DataFrame so that the exact rank of x is close to (p * N). More precisely,
floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).
This method implements a variation of the Greenwald-Khanna algorithm (with some speed optimizations). The algorithm was first presented in Space-efficient Online Computation of Quantile Summaries (https://doi.org/10.1145/375663.375670) by Greenwald and Khanna.
Note that null values will be ignored in numerical columns before calculation. For columns only containing null values, an empty list is returned.
New in version 2.0.0.
Parameters
col: str, tuple or list Can be a single column name, or a list of names for multiple columns.
Changed in version 2.2: Added support for multiple columns.
probabilities [list or tuple] a list of quantile probabilities. Each number must belong to [0, 1]. For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
relativeError [float] The relative target precision to achieve (>= 0). If set to zero, the exact quantiles are computed, which could be very expensive. Note that values greater than 1 are accepted but give the same result as 1.
Returns
list the approximate quantiles at the given probabilities. If the input col is a string, the output is a list of floats. If the input col is a list or tuple of strings, the output is also a list, but each element in it is a list of floats, i.e., the output is a list of lists of floats.
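For illustration (an added sketch; with relativeError set to zero the exact quantiles are computed, so probabilities 0.0 and 1.0 return the exact min and max of the two-row df):

>>> df.approxQuantile("age", [0.0, 1.0], 0.0)
[2.0, 5.0]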
pyspark.sql.DataFrameStatFunctions.corr
DataFrameStatFunctions.corr(col1, col2, method=None)
Calculates the correlation of two columns of a DataFrame as a double value. Currently only supports the Pearson Correlation Coefficient. DataFrame.corr() and DataFrameStatFunctions.corr() are aliases of each other.
New in version 1.4.0.
Parameters
col1 [str] The name of the first column
col2 [str] The name of the second column
method [str, optional] The correlation method. Currently only supports “pearson”
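For illustration (df_c is a hypothetical frame; column b is exactly 2 * a, so the Pearson coefficient is 1.0):

>>> df_c = spark.createDataFrame([(1, 2), (2, 4), (3, 6)], ["a", "b"])
>>> df_c.stat.corr("a", "b")
1.0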
pyspark.sql.DataFrameStatFunctions.cov
DataFrameStatFunctions.cov(col1, col2)
Calculate the sample covariance for the given columns, specified by their names, as a double value. DataFrame.cov() and DataFrameStatFunctions.cov() are aliases.
New in version 1.4.0.
Parameters
col1 [str] The name of the first column
col2 [str] The name of the second column
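For illustration (df_c is a hypothetical frame; the sample covariance of a=[1, 2, 3] and b=[2, 4, 6] is 4 / (3 - 1) = 2.0):

>>> df_c = spark.createDataFrame([(1, 2), (2, 4), (3, 6)], ["a", "b"])
>>> df_c.stat.cov("a", "b")
2.0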
pyspark.sql.DataFrameStatFunctions.crosstab
DataFrameStatFunctions.crosstab(col1, col2)
Computes a pair-wise frequency table of the given columns. Also known as a contingency table. The number of distinct values for each column should be less than 1e4. At most 1e6 non-zero pair frequencies will be returned. The first column of each row will be the distinct values of col1 and the column names will be the distinct values of col2. The name of the first column will be $col1_$col2. Pairs that have no occurrences will have zero as their counts. DataFrame.crosstab() and DataFrameStatFunctions.crosstab() are aliases.
New in version 1.4.0.
Parameters
col1 [str] The name of the first column. Distinct items will make the first item of each row.
col2 [str] The name of the second column. Distinct items will make the column names of theDataFrame.
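A short illustrative sketch (df_ct is hypothetical; one possible rendering is shown, since the order of the value columns is not guaranteed):

>>> df_ct = spark.createDataFrame([(1, 'a'), (1, 'b'), (2, 'a')], ["k", "v"])
>>> df_ct.stat.crosstab("k", "v").orderBy("k_v").show()
+---+---+---+
|k_v|  a|  b|
+---+---+---+
|  1|  1|  1|
|  2|  1|  0|
+---+---+---+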
pyspark.sql.DataFrameStatFunctions.freqItems
DataFrameStatFunctions.freqItems(cols, support=None)
Finding frequent items for columns, possibly with false positives. Using the frequent element count algorithm described in "https://doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". DataFrame.freqItems() and DataFrameStatFunctions.freqItems() are aliases.
New in version 1.4.0.
Parameters
cols [list or tuple] Names of the columns to calculate frequent items for as a list or tuple of strings.
support [float, optional] The frequency with which to consider an item ‘frequent’. Default is1%. The support must be greater than 1e-4.
Notes
This function is meant for exploratory data analysis, as we make no guarantee about the backward compatibilityof the schema of the resulting DataFrame.
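A short illustrative sketch (df_f is hypothetical; the result column is named a_freqItems, and items at or above the support threshold are guaranteed to be included):

>>> df_f = spark.createDataFrame([(1,), (1,), (2,)], ["a"])
>>> row = df_f.freqItems(["a"], support=0.6).collect()[0]
>>> 1 in row["a_freqItems"]
True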
pyspark.sql.DataFrameStatFunctions.sampleBy
DataFrameStatFunctions.sampleBy(col, fractions, seed=None)
Returns a stratified sample without replacement based on the fraction given on each stratum.
New in version 1.5.0.
Parameters
col [Column or str] column that defines strata
Changed in version 3.0: Added sampling by a column of Column
fractions [dict] sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as zero.
seed [int, optional] random seed
Returns
a new DataFrame that represents the stratified sample
Examples
>>> from pyspark.sql.functions import col
>>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
>>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
>>> sampled.groupBy("key").count().orderBy("key").show()
+---+-----+
|key|count|
+---+-----+
|  0|    3|
|  1|    6|
+---+-----+
>>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
33
3.1.5 Column APIs
Column.alias(*alias, **kwargs)  Returns this column aliased with a new name or names (in the case of expressions that return more than one column, such as explode).
Column.asc()  Returns a sort expression based on ascending order of the column.
Column.asc_nulls_first()  Returns a sort expression based on ascending order of the column, and null values return before non-null values.
Column.asc_nulls_last()  Returns a sort expression based on ascending order of the column, and null values appear after non-null values.
Column.astype(dataType)  astype() is an alias for cast().
Column.between(lowerBound, upperBound)  A boolean expression that is evaluated to true if the value of this expression is between the given columns.
Column.bitwiseAND(other)  Compute bitwise AND of this expression with another expression.
Column.bitwiseOR(other)  Compute bitwise OR of this expression with another expression.
Column.bitwiseXOR(other)  Compute bitwise XOR of this expression with another expression.
Column.cast(dataType)  Convert the column into type dataType.
Column.contains(other)  Contains the other element.
Column.desc()  Returns a sort expression based on the descending order of the column.
Column.desc_nulls_first()  Returns a sort expression based on the descending order of the column, and null values appear before non-null values.
Column.desc_nulls_last()  Returns a sort expression based on the descending order of the column, and null values appear after non-null values.
Column.dropFields(*fieldNames)  An expression that drops fields in StructType by name.
Column.endswith(other)  String ends with.
Column.eqNullSafe(other)  Equality test that is safe for null values.
Column.getField(name)  An expression that gets a field by name in a StructField.
Column.getItem(key)  An expression that gets an item at position ordinal out of a list, or gets an item by key out of a dict.
Column.isNotNull()  True if the current expression is NOT null.
Column.isNull()  True if the current expression is null.
Column.isin(*cols)  A boolean expression that is evaluated to true if the value of this expression is contained by the evaluated values of the arguments.
Column.like(other)  SQL like expression.
Column.name(*alias, **kwargs)  name() is an alias for alias().
Column.otherwise(value)  Evaluates a list of conditions and returns one of multiple possible result expressions.
Column.over(window)  Define a windowing column.
Column.rlike(other)  SQL RLIKE expression (LIKE with Regex).
Column.startswith(other)  String starts with.
Column.substr(startPos, length)  Return a Column which is a substring of the column.
Column.when(condition, value)  Evaluates a list of conditions and returns one of multiple possible result expressions.
Column.withField(fieldName, col)  An expression that adds/replaces a field in StructType by name.
pyspark.sql.Column.alias
Column.alias(*alias, **kwargs)
Returns this column aliased with a new name or names (in the case of expressions that return more than one column, such as explode).
New in version 1.3.0.
Parameters
alias [str] desired column names (collects all positional arguments passed)
Other Parameters
metadata: dict a dict of information to be stored in the metadata attribute of the corresponding StructField (optional, keyword only argument)
Changed in version 2.2.0: Added optional metadata argument.
Examples
>>> df.select(df.age.alias("age2")).collect()
[Row(age2=2), Row(age2=5)]
>>> df.select(df.age.alias("age3", metadata={'max': 99})).schema['age3'].metadata['max']
99
pyspark.sql.Column.asc
Column.asc()
Returns a sort expression based on ascending order of the column.
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.asc()).collect()
[Row(name='Alice'), Row(name='Tom')]
pyspark.sql.Column.asc_nulls_first
Column.asc_nulls_first()
Returns a sort expression based on ascending order of the column, and null values return before non-null values.
New in version 2.4.0.
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.asc_nulls_first()).collect()
[Row(name=None), Row(name='Alice'), Row(name='Tom')]
pyspark.sql.Column.asc_nulls_last
Column.asc_nulls_last()
Returns a sort expression based on ascending order of the column, and null values appear after non-null values.
New in version 2.4.0.
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.asc_nulls_last()).collect()
[Row(name='Alice'), Row(name='Tom'), Row(name=None)]
pyspark.sql.Column.astype
Column.astype(dataType)
astype() is an alias for cast().
New in version 1.4.
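For illustration (an added sketch mirroring the cast() example; assumes the two-row df used throughout):

>>> df.select(df.age.astype("string").alias('ages')).collect()
[Row(ages='2'), Row(ages='5')]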
pyspark.sql.Column.between
Column.between(lowerBound, upperBound)
A boolean expression that is evaluated to true if the value of this expression is between the given columns.
New in version 1.3.0.
Examples
>>> df.select(df.name, df.age.between(2, 4)).show()
+-----+---------------------------+
| name|((age >= 2) AND (age <= 4))|
+-----+---------------------------+
|Alice|                       true|
|  Bob|                      false|
+-----+---------------------------+
pyspark.sql.Column.bitwiseAND
Column.bitwiseAND(other)
Compute bitwise AND of this expression with another expression.
Parameters
other a value or Column to calculate bitwise and(&) with this Column.
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(a=170, b=75)])
>>> df.select(df.a.bitwiseAND(df.b)).collect()
[Row((a & b)=10)]
pyspark.sql.Column.bitwiseOR
Column.bitwiseOR(other)
Compute bitwise OR of this expression with another expression.
Parameters
other a value or Column to calculate bitwise or(|) with this Column.
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(a=170, b=75)])
>>> df.select(df.a.bitwiseOR(df.b)).collect()
[Row((a | b)=235)]
pyspark.sql.Column.bitwiseXOR
Column.bitwiseXOR(other)
Compute bitwise XOR of this expression with another expression.
Parameters
other a value or Column to calculate bitwise xor(^) with this Column.
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(a=170, b=75)])
>>> df.select(df.a.bitwiseXOR(df.b)).collect()
[Row((a ^ b)=225)]
pyspark.sql.Column.cast
Column.cast(dataType)
Convert the column into type dataType.
New in version 1.3.0.
Examples
>>> from pyspark.sql.types import StringType
>>> df.select(df.age.cast("string").alias('ages')).collect()
[Row(ages='2'), Row(ages='5')]
>>> df.select(df.age.cast(StringType()).alias('ages')).collect()
[Row(ages='2'), Row(ages='5')]
pyspark.sql.Column.contains
Column.contains(other)
Contains the other element. Returns a boolean Column based on a string match.
Parameters
other string in line. A value as a literal or a Column.
Examples
>>> df.filter(df.name.contains('o')).collect()
[Row(age=5, name='Bob')]
pyspark.sql.Column.desc
Column.desc()
Returns a sort expression based on the descending order of the column.
New in version 2.4.0.
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.desc()).collect()
[Row(name='Tom'), Row(name='Alice')]
pyspark.sql.Column.desc_nulls_first
Column.desc_nulls_first()
Returns a sort expression based on the descending order of the column, and null values appear before non-null values.
New in version 2.4.0.
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.desc_nulls_first()).collect()
[Row(name=None), Row(name='Tom'), Row(name='Alice')]
pyspark.sql.Column.desc_nulls_last
Column.desc_nulls_last()
Returns a sort expression based on the descending order of the column, and null values appear after non-null values.
New in version 2.4.0.
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.desc_nulls_last()).collect()
[Row(name='Tom'), Row(name='Alice'), Row(name=None)]
pyspark.sql.Column.dropFields
Column.dropFields(*fieldNames)
An expression that drops fields in StructType by name.
New in version 3.1.0.
Examples
>>> from pyspark.sql import Row
>>> from pyspark.sql.functions import col, lit
>>> df = spark.createDataFrame([
...     Row(a=Row(b=1, c=2, d=3, e=Row(f=4, g=5, h=6)))])
>>> df.withColumn('a', df['a'].dropFields('b')).show()
+-----------------+
|                a|
+-----------------+
|{2, 3, {4, 5, 6}}|
+-----------------+
>>> df.withColumn('a', df['a'].dropFields('b', 'c')).show()
+--------------+
|             a|
+--------------+
|{3, {4, 5, 6}}|
+--------------+
This method supports dropping multiple nested fields directly e.g.
>>> df.withColumn("a", col("a").dropFields("e.g", "e.h")).show()
+--------------+
|             a|
+--------------+
|{1, 2, 3, {4}}|
+--------------+
However, if you are going to add/replace multiple nested fields, it is preferred to extract out the nested struct before adding/replacing multiple fields e.g.
>>> df.select(col("a").withField(
...     "e", col("a.e").dropFields("g", "h")).alias("a")
... ).show()
+--------------+
|             a|
+--------------+
|{1, 2, 3, {4}}|
+--------------+
pyspark.sql.Column.endswith
Column.endswith(other)
String ends with. Returns a boolean Column based on a string match.
Parameters
other [Column or str] string at end of line (do not use a regex $)
Examples
>>> df.filter(df.name.endswith('ice')).collect()
[Row(age=2, name='Alice')]
>>> df.filter(df.name.endswith('ice$')).collect()
[]
pyspark.sql.Column.eqNullSafe
Column.eqNullSafe(other)
Equality test that is safe for null values.
New in version 2.3.0.
Parameters
other a value or Column
Notes
Unlike Pandas, PySpark doesn’t consider NaN values to be NULL. See the NaN Semantics for details.
Examples
>>> from pyspark.sql import Row
>>> df1 = spark.createDataFrame([
...     Row(id=1, value='foo'),
...     Row(id=2, value=None)
... ])
>>> df1.select(
...     df1['value'] == 'foo',
...     df1['value'].eqNullSafe('foo'),
...     df1['value'].eqNullSafe(None)
... ).show()
+-------------+---------------+----------------+
|(value = foo)|(value <=> foo)|(value <=> NULL)|
+-------------+---------------+----------------+
|         true|           true|           false|
|         null|          false|            true|
+-------------+---------------+----------------+
>>> df2 = spark.createDataFrame([
...     Row(value = 'bar'),
...     Row(value = None)
... ])
>>> df1.join(df2, df1["value"] == df2["value"]).count()
0
>>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count()
1
>>> df2 = spark.createDataFrame([
...     Row(id=1, value=float('NaN')),
...     Row(id=2, value=42.0),
...     Row(id=3, value=None)
... ])
>>> df2.select(
...     df2['value'].eqNullSafe(None),
...     df2['value'].eqNullSafe(float('NaN')),
...     df2['value'].eqNullSafe(42.0)
... ).show()
+----------------+---------------+----------------+
|(value <=> NULL)|(value <=> NaN)|(value <=> 42.0)|
+----------------+---------------+----------------+
|           false|           true|           false|
|           false|          false|            true|
|            true|          false|           false|
+----------------+---------------+----------------+
pyspark.sql.Column.getField
Column.getField(name)
An expression that gets a field by name in a StructField.
New in version 1.3.0.
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(r=Row(a=1, b="b"))])
>>> df.select(df.r.getField("b")).show()
+---+
|r.b|
+---+
|  b|
+---+
>>> df.select(df.r.a).show()
+---+
|r.a|
+---+
|  1|
+---+
pyspark.sql.Column.getItem
Column.getItem(key)
An expression that gets an item at position ordinal out of a list, or gets an item by key out of a dict.
New in version 1.3.0.
Examples
>>> df = spark.createDataFrame([([1, 2], {"key": "value"})], ["l", "d"])
>>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
+----+------+
|l[0]|d[key]|
+----+------+
|   1| value|
+----+------+
pyspark.sql.Column.isNotNull
Column.isNotNull()
True if the current expression is NOT null.
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(name='Tom', height=80), Row(name='Alice', height=None)])
>>> df.filter(df.height.isNotNull()).collect()
[Row(name='Tom', height=80)]
pyspark.sql.Column.isNull
Column.isNull()
True if the current expression is null.
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(name='Tom', height=80), Row(name='Alice', height=None)])
>>> df.filter(df.height.isNull()).collect()
[Row(name='Alice', height=None)]
pyspark.sql.Column.isin
Column.isin(*cols)
A boolean expression that is evaluated to true if the value of this expression is contained by the evaluated values of the arguments.
New in version 1.5.0.
Examples
>>> df[df.name.isin("Bob", "Mike")].collect()
[Row(age=5, name='Bob')]
>>> df[df.age.isin([1, 2, 3])].collect()
[Row(age=2, name='Alice')]
pyspark.sql.Column.like
Column.like(other)
SQL like expression. Returns a boolean Column based on a SQL LIKE match.
Parameters
other [str] a SQL LIKE pattern
See also:
pyspark.sql.Column.rlike
Examples
>>> df.filter(df.name.like('Al%')).collect()
[Row(age=2, name='Alice')]
pyspark.sql.Column.name
Column.name(*alias, **kwargs)
name() is an alias for alias().
New in version 2.0.
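For illustration (an added sketch; assumes the two-row df used throughout):

>>> df.select(df.age.name("age2")).collect()
[Row(age2=2), Row(age2=5)]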
pyspark.sql.Column.otherwise
Column.otherwise(value)
Evaluates a list of conditions and returns one of multiple possible result expressions. If Column.otherwise() is not invoked, None is returned for unmatched conditions.
New in version 1.4.0.
Parameters
value a literal value, or a Column expression.
See also:
pyspark.sql.functions.when
Examples
>>> from pyspark.sql import functions as F
>>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
+-----+-------------------------------------+
| name|CASE WHEN (age > 3) THEN 1 ELSE 0 END|
+-----+-------------------------------------+
|Alice|                                    0|
|  Bob|                                    1|
+-----+-------------------------------------+
pyspark.sql.Column.over
Column.over(window)
Define a windowing column.
New in version 1.4.0.
Parameters
window [WindowSpec]
Returns
Column
Examples
>>> from pyspark.sql import Window
>>> window = Window.partitionBy("name").orderBy("age") \
...     .rowsBetween(Window.unboundedPreceding, Window.currentRow)
>>> from pyspark.sql.functions import rank, min
>>> from pyspark.sql.functions import desc
>>> df.withColumn("rank", rank().over(window)) \
...     .withColumn("min", min('age').over(window)).sort(desc("age")).show()
+---+-----+----+---+
|age| name|rank|min|
+---+-----+----+---+
|  5|  Bob|   1|  5|
|  2|Alice|   1|  2|
+---+-----+----+---+
pyspark.sql.Column.rlike
Column.rlike(other)
SQL RLIKE expression (LIKE with Regex). Returns a boolean Column based on a regex match.
Parameters
other [str] an extended regex expression
Examples
>>> df.filter(df.name.rlike('ice$')).collect()
[Row(age=2, name='Alice')]
pyspark.sql.Column.startswith
Column.startswith(other)
String starts with. Returns a boolean Column based on a string match.
Parameters
other [Column or str] string at start of line (do not use a regex ^)
Examples
>>> df.filter(df.name.startswith('Al')).collect()
[Row(age=2, name='Alice')]
>>> df.filter(df.name.startswith('^Al')).collect()
[]
pyspark.sql.Column.substr
Column.substr(startPos, length)
Return a Column which is a substring of the column.
New in version 1.3.0.
Parameters
startPos [Column or int] start position
length [Column or int] length of the substring
Examples
>>> df.select(df.name.substr(1, 3).alias("col")).collect()
[Row(col='Ali'), Row(col='Bob')]
pyspark.sql.Column.when
Column.when(condition, value)
Evaluates a list of conditions and returns one of multiple possible result expressions. If Column.otherwise() is not invoked, None is returned for unmatched conditions.
New in version 1.4.0.
Parameters
condition [Column] a boolean Column expression.
value a literal value, or a Column expression.
See also:
pyspark.sql.functions.when
Examples
>>> from pyspark.sql import functions as F
>>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
+-----+------------------------------------------------------------+
| name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END|
+-----+------------------------------------------------------------+
|Alice|                                                          -1|
|  Bob|                                                           1|
+-----+------------------------------------------------------------+
pyspark.sql.Column.withField
Column.withField(fieldName, col)
An expression that adds/replaces a field in StructType by name.
New in version 3.1.0.
Examples
>>> from pyspark.sql import Row
>>> from pyspark.sql.functions import lit
>>> df = spark.createDataFrame([Row(a=Row(b=1, c=2))])
>>> df.withColumn('a', df['a'].withField('b', lit(3))).select('a.b').show()
+---+
|  b|
+---+
|  3|
+---+
>>> df.withColumn('a', df['a'].withField('d', lit(4))).select('a.d').show()
+---+
|  d|
+---+
|  4|
+---+
3.1.6 Data Types
ArrayType(elementType[, containsNull])  Array data type.
BinaryType()  Binary (byte array) data type.
BooleanType()  Boolean data type.
ByteType()  Byte data type, i.e. a signed integer in a single byte.
DataType()  Base class for data types.
DateType()  Date (datetime.date) data type.
DecimalType([precision, scale])  Decimal (decimal.Decimal) data type.
DoubleType()  Double data type, representing double precision floats.
FloatType()  Float data type, representing single precision floats.
IntegerType()  Int data type, i.e. a signed 32-bit integer.
LongType()  Long data type, i.e. a signed 64-bit integer.
MapType(keyType, valueType[, valueContainsNull])  Map data type.
NullType()  Null type.
ShortType()  Short data type, i.e. a signed 16-bit integer.
StringType()  String data type.
StructField(name, dataType[, nullable, metadata])  A field in StructType.
StructType([fields])  Struct type, consisting of a list of StructField.
TimestampType()  Timestamp (datetime.datetime) data type.
ArrayType
class pyspark.sql.types.ArrayType(elementType, containsNull=True)
Array data type.
Parameters
elementType [DataType] DataType of each element in the array.
containsNull [bool, optional] whether the array can contain null (None) values.
Examples
>>> ArrayType(StringType()) == ArrayType(StringType(), True)
True
>>> ArrayType(StringType(), False) == ArrayType(StringType())
False
Methods
fromInternal(obj)  Converts an internal SQL object into a native Python object.
fromJson(json)
json()
jsonValue()
needConversion()  Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)  Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
Converts an internal SQL object into a native Python object.
classmethod fromJson(json)
json()
jsonValue()
needConversion()
Does this type need conversion between Python object and internal SQL object.
This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
Converts a Python object into an internal SQL object.
classmethod typeName()
BinaryType
class pyspark.sql.types.BinaryType
Binary (byte array) data type.
Methods
fromInternal(obj)  Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()  Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)  Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
BooleanType
class pyspark.sql.types.BooleanType
    Boolean data type.
Methods
fromInternal(obj) Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
ByteType
class pyspark.sql.types.ByteType
    Byte data type, i.e. a signed integer in a single byte.
Methods
fromInternal(obj) Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
DataType
class pyspark.sql.types.DataType
    Base class for data types.
Methods
fromInternal(obj) Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
DateType
class pyspark.sql.types.DateType
    Date (datetime.date) data type.
Methods
fromInternal(v) Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(d) Converts a Python object into an internal SQL object.
typeName()
Attributes
EPOCH_ORDINAL
Methods Documentation
fromInternal(v)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(d)
    Converts a Python object into an internal SQL object.
classmethod typeName()
Attributes Documentation
EPOCH_ORDINAL = 719163
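EPOCH_ORDINAL is the proleptic Gregorian ordinal of the Unix epoch (1970-01-01), which DateType uses to convert between Python dates and internal integer day counts; the value can be checked in plain Python:

>>> import datetime
>>> datetime.date(1970, 1, 1).toordinal()
719163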
DecimalType
class pyspark.sql.types.DecimalType(precision=10, scale=0)
    Decimal (decimal.Decimal) data type.
    The DecimalType must have fixed precision (the maximum total number of digits) and scale (the number of digits to the right of the decimal point). For example, (5, 2) can support values in the range [-999.99, 999.99].
    The precision can be up to 38; the scale must be less than or equal to the precision.
    When creating a DecimalType, the default precision and scale is (10, 0). When inferring schema from decimal.Decimal objects, it will be DecimalType(38, 18).
Parameters
precision [int, optional] the maximum (i.e. total) number of digits (default: 10)

scale [int, optional] the number of digits to the right of the decimal point (default: 0)
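A minimal sketch of how precision and scale appear in the type's DDL string:

>>> from pyspark.sql.types import DecimalType
>>> DecimalType(5, 2).simpleString()
'decimal(5,2)'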
Methods
fromInternal(obj) Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
DoubleType
class pyspark.sql.types.DoubleType
    Double data type, representing double precision floats.
Methods
fromInternal(obj) Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
FloatType
class pyspark.sql.types.FloatType
    Float data type, representing single precision floats.
Methods
fromInternal(obj) Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
IntegerType
class pyspark.sql.types.IntegerType
    Int data type, i.e. a signed 32-bit integer.
Methods
fromInternal(obj) Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
LongType
class pyspark.sql.types.LongType
    Long data type, i.e. a signed 64-bit integer.
    If the values are beyond the range of [-9223372036854775808, 9223372036854775807], please use DecimalType.
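In DDL strings the type is spelled bigint, as a quick check of the type's DDL name shows:

>>> from pyspark.sql.types import LongType
>>> LongType().simpleString()
'bigint'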
Methods
fromInternal(obj) Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
MapType
class pyspark.sql.types.MapType(keyType, valueType, valueContainsNull=True)
    Map data type.
Parameters
keyType [DataType] DataType of the keys in the map.
valueType [DataType] DataType of the values in the map.
valueContainsNull [bool, optional] indicates whether values can contain null (None) values.
Notes
Keys in a map data type are not allowed to be null (None).
Examples
>>> (MapType(StringType(), IntegerType())
...  == MapType(StringType(), IntegerType(), True))
True
>>> (MapType(StringType(), IntegerType(), False)
...  == MapType(StringType(), FloatType()))
False
Methods
fromInternal(obj) Converts an internal SQL object into a native Python object.
fromJson(json)
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
classmethod fromJson(json)
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
NullType
class pyspark.sql.types.NullType
    Null type.
    The data type representing None, used for the types that cannot be inferred.
Methods
fromInternal(obj) Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
ShortType
class pyspark.sql.types.ShortType
    Short data type, i.e. a signed 16-bit integer.
Methods
fromInternal(obj) Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
StringType
class pyspark.sql.types.StringType
    String data type.
Methods
fromInternal(obj) Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
StructField
class pyspark.sql.types.StructField(name, dataType, nullable=True, metadata=None)
    A field in StructType.
Parameters
name [str] name of the field.
dataType [DataType] DataType of the field.
nullable [bool] whether the field can be null (None) or not.
metadata [dict] a dict from string to simple type that can be automatically converted to JSON
Examples
>>> (StructField("f1", StringType(), True)... == StructField("f1", StringType(), True))True>>> (StructField("f1", StringType(), True)... == StructField("f2", StringType(), True))False
Methods
fromInternal(obj) Converts an internal SQL object into a native Python object.
fromJson(json)
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
classmethod fromJson(json)
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
typeName()
StructType
class pyspark.sql.types.StructType(fields=None)
    Struct type, consisting of a list of StructField.
    This is the data type representing a Row.
    Iterating a StructType will iterate over its StructFields. A contained StructField can be accessed by its name or position.
Examples
>>> struct1 = StructType([StructField("f1", StringType(), True)])
>>> struct1["f1"]
StructField(f1,StringType,true)
>>> struct1[0]
StructField(f1,StringType,true)

>>> struct1 = StructType([StructField("f1", StringType(), True)])
>>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
>>> struct1 = StructType([StructField("f1", StringType(), True)])
>>> struct2 = StructType([StructField("f1", StringType(), True),
...                       StructField("f2", IntegerType(), False)])
>>> struct1 == struct2
False
Methods
add(field[, data_type, nullable, metadata]) Construct a StructType by adding new elements to it, to define the schema.
fieldNames() Returns all field names in a list.
fromInternal(obj) Converts an internal SQL object into a native Python object.
fromJson(json)
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
add(field, data_type=None, nullable=True, metadata=None)
    Construct a StructType by adding new elements to it, to define the schema. The method accepts either:
    a) A single parameter which is a StructField object.
    b) Between 2 and 4 parameters as (name, data_type, nullable (optional), metadata (optional)). The data_type parameter may be either a String or a DataType object.
Parameters
field [str or StructField] Either the name of the field or a StructField object
data_type [DataType, optional] If present, the DataType of the StructField to create
nullable [bool, optional] Whether the field to add should be nullable (default True)
metadata [dict, optional] Any additional metadata (default None)
Returns
StructType
Examples
>>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
>>> struct2 = StructType([StructField("f1", StringType(), True),
...                       StructField("f2", StringType(), True, None)])
>>> struct1 == struct2
True
>>> struct1 = StructType().add(StructField("f1", StringType(), True))
>>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
>>> struct1 = StructType().add("f1", "string", True)
>>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
fieldNames()
    Returns all field names in a list.
Examples
>>> struct = StructType([StructField("f1", StringType(), True)])
>>> struct.fieldNames()
['f1']
fromInternal(obj)
    Converts an internal SQL object into a native Python object.
classmethod fromJson(json)
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(obj)
    Converts a Python object into an internal SQL object.
classmethod typeName()
TimestampType
class pyspark.sql.types.TimestampType
    Timestamp (datetime.datetime) data type.
Methods
fromInternal(ts) Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion() Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(dt) Converts a Python object into an internal SQL object.
typeName()
Methods Documentation
fromInternal(ts)
    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()
    Does this type need conversion between Python object and internal SQL object.
    This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
simpleString()
toInternal(dt)
    Converts a Python object into an internal SQL object.
classmethod typeName()
3.1.7 Row
Row.asDict([recursive]) Return as a dict
pyspark.sql.Row.asDict
Row.asDict(recursive=False)
    Return as a dict.
Parameters
recursive [bool, optional] turns the nested Rows into dicts (default: False).
Notes
If a row contains duplicate field names, e.g., the rows of a join between two DataFrames that both have fields with the same names, one of the duplicate fields will be selected by asDict. __getitem__ will also return one of the duplicate fields, but the value it returns may differ from the one asDict selects.
Examples
>>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}True>>> row = Row(key=1, value=Row(name='a', age=2))>>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)}True>>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}True
3.1.8 Functions
abs(col) Computes the absolute value.
acos(col) New in version 1.4.0.
acosh(col) Computes inverse hyperbolic cosine of the input column.
add_months(start, months) Returns the date that is months months after start.
aggregate(col, zero, merge[, finish]) Applies a binary operator to an initial state and all elements in the array, and reduces this to a single state.
approxCountDistinct(col[, rsd]) Deprecated since version 2.1.0.
approx_count_distinct(col[, rsd]) Aggregate function: returns a new Column for approximate distinct count of column col.
array(*cols) Creates a new array column.
array_contains(col, value) Collection function: returns null if the array is null, true if the array contains the given value, and false otherwise.
array_distinct(col) Collection function: removes duplicate values from the array.
array_except(col1, col2) Collection function: returns an array of the elements in col1 but not in col2, without duplicates.
array_intersect(col1, col2) Collection function: returns an array of the elements in the intersection of col1 and col2, without duplicates.
array_join(col, delimiter[, null_replacement]) Concatenates the elements of column using the delimiter.
array_max(col) Collection function: returns the maximum value of the array.
array_min(col) Collection function: returns the minimum value of the array.
array_position(col, value) Collection function: Locates the position of the first occurrence of the given value in the given array.
array_remove(col, element) Collection function: Removes all elements equal to element from the given array.
array_repeat(col, count) Collection function: creates an array containing a column repeated count times.
array_sort(col) Collection function: sorts the input array in ascending order.
array_union(col1, col2) Collection function: returns an array of the elements in the union of col1 and col2, without duplicates.
arrays_overlap(a1, a2) Collection function: returns true if the arrays contain any common non-null element; if not, returns null if both the arrays are non-empty and any of them contains a null element; returns false otherwise.
arrays_zip(*cols) Collection function: Returns a merged array of structs in which the N-th struct contains all N-th values of input arrays.
asc(col) Returns a sort expression based on the ascending order of the given column name.
asc_nulls_first(col) Returns a sort expression based on the ascending order of the given column name, and null values appear before non-null values.
asc_nulls_last(col) Returns a sort expression based on the ascending order of the given column name, and null values appear after non-null values.
ascii(col) Computes the numeric value of the first character of the string column.
asin(col) New in version 1.3.0.
asinh(col) Computes inverse hyperbolic sine of the input column.
assert_true(col[, errMsg]) Returns null if the input column is true; throws an exception with the provided error message otherwise.
atan(col) New in version 1.4.0.
atanh(col) Computes inverse hyperbolic tangent of the input column.
atan2(col1, col2) New in version 1.4.0.
avg(col) Aggregate function: returns the average of the values in a group.
base64(col) Computes the BASE64 encoding of a binary column and returns it as a string column.
bin(col) Returns the string representation of the binary value of the given column.
bitwiseNOT(col) Computes bitwise not.
broadcast(df) Marks a DataFrame as small enough for use in broadcast joins.
bround(col[, scale]) Round the given value to scale decimal places using HALF_EVEN rounding mode if scale >= 0 or at integral part when scale < 0.
bucket(numBuckets, col) Partition transform function: A transform for any type that partitions by a hash of the input column.
cbrt(col) Computes the cube-root of the given value.
ceil(col) Computes the ceiling of the given value.
coalesce(*cols) Returns the first column that is not null.
col(col) Returns a Column based on the given column name.
collect_list(col) Aggregate function: returns a list of objects with duplicates.
collect_set(col) Aggregate function: returns a set of objects with duplicate elements eliminated.
column(col) Returns a Column based on the given column name.
concat(*cols) Concatenates multiple input columns together into a single column.
concat_ws(sep, *cols) Concatenates multiple input string columns together into a single string column, using the given separator.
conv(col, fromBase, toBase) Convert a number in a string column from one base to another.
corr(col1, col2) Returns a new Column for the Pearson Correlation Coefficient for col1 and col2.
cos(col) New in version 1.4.0.
cosh(col) New in version 1.4.0.
count(col) Aggregate function: returns the number of items in a group.
countDistinct(col, *cols) Returns a new Column for distinct count of col or cols.
covar_pop(col1, col2) Returns a new Column for the population covariance of col1 and col2.
covar_samp(col1, col2) Returns a new Column for the sample covariance of col1 and col2.
crc32(col) Calculates the cyclic redundancy check value (CRC32) of a binary column and returns the value as a bigint.
create_map(*cols) Creates a new map column.
cume_dist() Window function: returns the cumulative distribution of values within a window partition, i.e. the fraction of rows that are below the current row.
current_date() Returns the current date at the start of query evaluation as a DateType column.
current_timestamp() Returns the current timestamp at the start of query evaluation as a TimestampType column.
date_add(start, days) Returns the date that is days days after start.
date_format(date, format) Converts a date/timestamp/string to a value of string in the format specified by the date format given by the second argument.
date_sub(start, days) Returns the date that is days days before start.
date_trunc(format, timestamp) Returns timestamp truncated to the unit specified by the format.
datediff(end, start) Returns the number of days from start to end.
dayofmonth(col) Extract the day of the month of a given date as integer.
dayofweek(col) Extract the day of the week of a given date as integer.
dayofyear(col) Extract the day of the year of a given date as integer.
days(col) Partition transform function: A transform for timestamps and dates to partition data into days.
decode(col, charset) Computes the first argument into a string from a binary using the provided character set (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
degrees(col) Converts an angle measured in radians to an approximately equivalent angle measured in degrees.
dense_rank() Window function: returns the rank of rows within a window partition, without any gaps.
desc(col) Returns a sort expression based on the descending order of the given column name.
desc_nulls_first(col) Returns a sort expression based on the descending order of the given column name, and null values appear before non-null values.
desc_nulls_last(col) Returns a sort expression based on the descending order of the given column name, and null values appear after non-null values.
element_at(col, extraction) Collection function: Returns element of array at given index in extraction if col is array.
encode(col, charset) Computes the first argument into a binary from a string using the provided character set (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
exists(col, f) Returns whether a predicate holds for one or more elements in the array.
exp(col) Computes the exponential of the given value.
explode(col) Returns a new row for each element in the given array or map.
explode_outer(col) Returns a new row for each element in the given array or map.
expm1(col) Computes the exponential of the given value minus one.
expr(str) Parses the expression string into the column that it represents.
factorial(col) Computes the factorial of the given value.
filter(col, f) Returns an array of elements for which a predicate holds in a given array.
first(col[, ignorenulls]) Aggregate function: returns the first value in a group.
flatten(col) Collection function: creates a single array from an array of arrays.
floor(col) Computes the floor of the given value.
forall(col, f) Returns whether a predicate holds for every element in the array.
format_number(col, d) Formats the number X to a format like '#,###,###.##', rounded to d decimal places with HALF_EVEN round mode, and returns the result as a string.
format_string(format, *cols) Formats the arguments in printf-style and returns the result as a string column.
from_csv(col, schema[, options]) Parses a column containing a CSV string to a row with the specified schema.
from_json(col, schema[, options]) Parses a column containing a JSON string into a MapType with StringType as keys type, StructType or ArrayType with the specified schema.
from_unixtime(timestamp[, format]) Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string representing the timestamp of that moment in the current system time zone in the given format.
from_utc_timestamp(timestamp, tz) This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE.
get_json_object(col, path) Extracts json object from a json string based on json path specified, and returns json string of the extracted json object.
greatest(*cols) Returns the greatest value of the list of column names, skipping null values.
grouping(col) Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated or not, returns 1 for aggregated or 0 for not aggregated in the result set.
grouping_id(*cols) Aggregate function: returns the level of grouping.
hash(*cols) Calculates the hash code of given columns, and returns the result as an int column.
hex(col) Computes hex value of the given column, which could be pyspark.sql.types.StringType, pyspark.sql.types.BinaryType, pyspark.sql.types.IntegerType or pyspark.sql.types.LongType.
hour(col) Extract the hours of a given date as integer.
hours(col) Partition transform function: A transform for timestamps to partition data into hours.
hypot(col1, col2) Computes sqrt(a^2 + b^2) without intermediate overflow or underflow.
initcap(col) Translate the first letter of each word to upper case in the sentence.
input_file_name() Creates a string column for the file name of the current Spark task.
instr(str, substr) Locate the position of the first occurrence of substr column in the given string.
isnan(col) An expression that returns true iff the column is NaN.
isnull(col) An expression that returns true iff the column is null.
json_tuple(col, *fields) Creates a new row for a json column according to the given field names.
kurtosis(col) Aggregate function: returns the kurtosis of the values in a group.
lag(col[, offset, default]) Window function: returns the value that is offset rows before the current row, and defaultValue if there are fewer than offset rows before the current row.
last(col[, ignorenulls]) Aggregate function: returns the last value in a group.
last_day(date) Returns the last day of the month which the given date belongs to.
lead(col[, offset, default]) Window function: returns the value that is offset rows after the current row, and defaultValue if there are fewer than offset rows after the current row.
least(*cols) Returns the least value of the list of column names, skipping null values.
length(col) Computes the character length of string data or number of bytes of binary data.
levenshtein(left, right) Computes the Levenshtein distance of the two given strings.
lit(col) Creates a Column of literal value.
locate(substr, str[, pos]) Locate the position of the first occurrence of substr in a string column, after position pos.
log(arg1[, arg2]) Returns the first argument-based logarithm of the second argument.
log10(col) Computes the logarithm of the given value in Base 10.
log1p(col) Computes the natural logarithm of the given value plus one.
log2(col) Returns the base-2 logarithm of the argument.
lower(col) Converts a string expression to lower case.
lpad(col, len, pad) Left-pad the string column to width len with pad.
ltrim(col) Trim the spaces from left end for the specified string value.
map_concat(*cols) Returns the union of all the given maps.
map_entries(col) Collection function: Returns an unordered array of all entries in the given map.
map_filter(col, f) Returns a map whose key-value pairs satisfy a predicate.
map_from_arrays(col1, col2) Creates a new map from two arrays.
map_from_entries(col) Collection function: Returns a map created from the given array of entries.
map_keys(col) Collection function: Returns an unordered array containing the keys of the map.
map_values(col) Collection function: Returns an unordered array containing the values of the map.
map_zip_with(col1, col2, f) Merge two given maps, key-wise into a single map using a function.
max(col) Aggregate function: returns the maximum value of the expression in a group.
md5(col) Calculates the MD5 digest and returns the value as a 32 character hex string.
mean(col) Aggregate function: returns the average of the values in a group.
min(col) Aggregate function: returns the minimum value of the expression in a group.
minute(col) Extract the minutes of a given date as integer.
monotonically_increasing_id() A column that generates monotonically increasing 64-bit integers.
month(col) Extract the month of a given date as integer.
months(col) Partition transform function: A transform for timestamps and dates to partition data into months.
months_between(date1, date2[, roundOff]) Returns number of months between dates date1 and date2.
nanvl(col1, col2) Returns col1 if it is not NaN, or col2 if col1 is NaN.
next_day(date, dayOfWeek) Returns the first date which is later than the value of the date column.
nth_value(col, offset[, ignoreNulls]) Window function: returns the value that is the offsetth row of the window frame (counting from 1), and null if the size of window frame is less than offset rows.
ntile(n) Window function: returns the ntile group id (from 1 to n inclusive) in an ordered window partition.
overlay(src, replace, pos[, len]) Overlay the specified portion of src with replace, starting from byte position pos of src and proceeding for len bytes.
pandas_udf([f, returnType, functionType]) Creates a pandas user defined function (a.k.a. vectorized user defined function).
percent_rank() Window function: returns the relative rank (i.e. percentile) of rows within a window partition.
percentile_approx(col, percentage[, accuracy]) Returns the approximate percentile of the numeric column col which is the smallest value in the ordered col values (sorted from least to greatest) such that no more than percentage of col values is less than the value or equal to that value.
posexplode(col) Returns a new row for each element with position in the given array or map.
posexplode_outer(col) Returns a new row for each element with position in the given array or map.
pow(col1, col2) Returns the value of the first argument raised to the power of the second argument.
quarter(col) Extract the quarter of a given date as integer.
radians(col) Converts an angle measured in degrees to an approximately equivalent angle measured in radians.
raise_error(errMsg) Throws an exception with the provided error message.
rand([seed]) Generates a random column with independent and identically distributed (i.i.d.) samples uniformly distributed in [0.0, 1.0).
randn([seed]) Generates a column with independent and identically distributed (i.i.d.) samples from the standard normal distribution.
rank() Window function: returns the rank of rows within a window partition.
regexp_extract(str, pattern, idx) Extract a specific group matched by a Java regex, from the specified string column.
regexp_replace(str, pattern, replacement) Replace all substrings of the specified string value that match regexp with rep.
repeat(col, n) Repeats a string column n times, and returns it as a new string column.
reverse(col) Collection function: returns a reversed string or an array with reverse order of elements.
rint(col) Returns the double value that is closest in value to the argument and is equal to a mathematical integer.
round(col[, scale]) Round the given value to scale decimal places using HALF_UP rounding mode if scale >= 0 or at integral part when scale < 0.
row_number() Window function: returns a sequential number starting at 1 within a window partition.
rpad(col, len, pad) Right-pad the string column to width len with pad.
rtrim(col) Trim the spaces from right end for the specified string value.
schema_of_csv(csv[, options]) Parses a CSV string and infers its schema in DDL format.
schema_of_json(json[, options]) Parses a JSON string and infers its schema in DDL format.
second(col) Extract the seconds of a given date as integer.
sequence(start, stop[, step]) Generate a sequence of integers from start to stop, incrementing by step.
sha1(col) Returns the hex string result of SHA-1.
sha2(col, numBits) Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, and SHA-512).
shiftLeft(col, numBits) Shift the given value numBits left.
shiftRight(col, numBits) (Signed) shift the given value numBits right.
shiftRightUnsigned(col, numBits) Unsigned shift the given value numBits right.
shuffle(col) Collection function: Generates a random permutation of the given array.
signum(col) Computes the signum of the given value.
sin(col) New in version 1.4.0.
sinh(col) New in version 1.4.0.
size(col) Collection function: returns the length of the array or map stored in the column.
skewness(col) Aggregate function: returns the skewness of the values in a group.
slice(x, start, length) Collection function: returns an array containing all the elements in x from index start (array indices start at 1, or from the end if start is negative) with the specified length.
sort_array(col[, asc]) Collection function: sorts the input array in ascending or descending order according to the natural ordering of the array elements.
soundex(col) Returns the SoundEx encoding for a string.
spark_partition_id() A column for partition ID.
split(str, pattern[, limit]) Splits str around matches of the given pattern.
sqrt(col) Computes the square root of the specified float value.
stddev(col) Aggregate function: alias for stddev_samp.
stddev_pop(col) Aggregate function: returns population standard deviation of the expression in a group.
stddev_samp(col) Aggregate function: returns the unbiased sample standard deviation of the expression in a group.
struct(*cols) Creates a new struct column.
substring(str, pos, len) Substring starts at pos and is of length len when str is String type or returns the slice of byte array that starts at pos in byte and is of length len when str is Binary type.
substring_index(str, delim, count) Returns the substring from string str before count occurrences of the delimiter delim.
sum(col) Aggregate function: returns the sum of all values in the expression.
sumDistinct(col) Aggregate function: returns the sum of distinct values in the expression.
tan(col) New in version 1.4.0.
tanh(col) New in version 1.4.0.
timestamp_seconds(col) New in version 3.1.0.
toDegrees(col) Deprecated since version 2.1.0.
toRadians(col) Deprecated since version 2.1.0.
to_csv(col[, options]) Converts a column containing a StructType into a CSV string.
to_date(col[, format]) Converts a Column into pyspark.sql.types.DateType using the optionally specified format.
to_json(col[, options]) Converts a column containing a StructType, ArrayType or a MapType into a JSON string.
to_timestamp(col[, format]) Converts a Column into pyspark.sql.types.TimestampType using the optionally specified format.
to_utc_timestamp(timestamp, tz) This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE.
transform(col, f) Returns an array of elements after applying a transformation to each element in the input array.
transform_keys(col, f) Applies a function to every key-value pair in a map and returns a map with the results of those applications as the new keys for the pairs.
transform_values(col, f) Applies a function to every key-value pair in a map and returns a map with the results of those applications as the new values for the pairs.
translate(srcCol, matching, replace) A function that translates any character in srcCol by a character in matching.
trim(col) Trim the spaces from both ends for the specified string column.
trunc(date, format) Returns date truncated to the unit specified by the format.
udf([f, returnType]) Creates a user defined function (UDF).
unbase64(col) Decodes a BASE64 encoded string column and returns it as a binary column.
unhex(col) Inverse of hex.
unix_timestamp([timestamp, format]) Convert time string with given pattern ('yyyy-MM-dd HH:mm:ss', by default) to Unix time stamp (in seconds), using the default timezone and the default locale, return null if fail.
upper(col) Converts a string expression to upper case.
var_pop(col) Aggregate function: returns the population variance of the values in a group.
var_samp(col) Aggregate function: returns the unbiased sample variance of the values in a group.
variance(col) Aggregate function: alias for var_samp.
weekofyear(col) Extract the week number of a given date as integer.
when(condition, value) Evaluates a list of conditions and returns one of multiple possible result expressions.
window(timeColumn, windowDuration[, ...]) Bucketize rows into one or more time windows given a timestamp specifying column.
xxhash64(*cols) Calculates the hash code of given columns using the 64-bit variant of the xxHash algorithm, and returns the result as a long column.
year(col) Extract the year of a given date as integer.
years(col) Partition transform function: A transform for timestamps and dates to partition data into years.
zip_with(col1, col2, f) Merge two given arrays, element-wise, into a single array using a function.
pyspark.sql.functions.abs
pyspark.sql.functions.abs(col)
    Computes the absolute value.
New in version 1.3.
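A minimal sketch, assuming an active SparkSession bound to spark and abs imported from pyspark.sql.functions:

>>> df_abs = spark.createDataFrame([(-1,), (2,)], ['v'])
>>> df_abs.select(abs(df_abs.v).alias('a')).collect()
[Row(a=1), Row(a=2)]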
pyspark.sql.functions.acos
pyspark.sql.functions.acos(col)
    New in version 1.4.0.
    Returns
        Column: inverse cosine of col, as if computed by java.lang.Math.acos()
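A minimal sketch (acos of 1.0 is 0.0), assuming lit is imported from pyspark.sql.functions:

>>> spark.range(1).select(acos(lit(1.0)).alias('r')).collect()
[Row(r=0.0)]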
pyspark.sql.functions.acosh
pyspark.sql.functions.acosh(col)
    Computes inverse hyperbolic cosine of the input column.
New in version 3.1.0.
Returns
Column
pyspark.sql.functions.add_months
pyspark.sql.functions.add_months(start, months)
    Returns the date that is months months after start.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(add_months(df.dt, 1).alias('next_month')).collect()
[Row(next_month=datetime.date(2015, 5, 8))]
pyspark.sql.functions.aggregate
pyspark.sql.functions.aggregate(col, zero, merge, finish=None)
    Applies a binary operator to an initial state and all elements in the array, and reduces this to a single state. The final state is converted into the final result by applying a finish function.
    Both functions can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).
New in version 3.1.0.
Parameters
col [Column or str] name of column or expression
zero [Column or str] initial value. Name of column or expression
merge [function] a binary function (acc: Column, x: Column) -> Column ... returning an expression of the same type as zero

finish [function] an optional unary function (x: Column) -> Column: ... used to convert the accumulated value.
Returns
pyspark.sql.Column
Examples
>>> df = spark.createDataFrame([(1, [20.0, 4.0, 2.0, 6.0, 10.0])], ("id", "values"))
>>> df.select(aggregate("values", lit(0.0), lambda acc, x: acc + x).alias("sum")).show()
+----+
| sum|
+----+
|42.0|
+----+
>>> def merge(acc, x):
...     count = acc.count + 1
...     sum = acc.sum + x
...     return struct(count.alias("count"), sum.alias("sum"))
>>> df.select(
...     aggregate(
...         "values",
...         struct(lit(0).alias("count"), lit(0.0).alias("sum")),
...         merge,
...         lambda acc: acc.sum / acc.count,
...     ).alias("mean")
... ).show()
+----+
|mean|
+----+
| 8.4|
+----+
pyspark.sql.functions.approxCountDistinct
pyspark.sql.functions.approxCountDistinct(col, rsd=None)
    Deprecated since version 2.1.0: Use approx_count_distinct() instead.
New in version 1.3.
pyspark.sql.functions.approx_count_distinct
pyspark.sql.functions.approx_count_distinct(col, rsd=None)
    Aggregate function: returns a new Column for approximate distinct count of column col.
New in version 2.1.0.
Parameters
col [Column or str]
rsd [float, optional] maximum relative standard deviation allowed (default = 0.05). For rsd < 0.01, it is more efficient to use countDistinct()
Examples
>>> df.agg(approx_count_distinct(df.age).alias('distinct_ages')).collect()
[Row(distinct_ages=2)]
pyspark.sql.functions.array
pyspark.sql.functions.array(*cols)
    Creates a new array column.
New in version 1.4.0.
Parameters
cols [Column or str] column names or Columns that have the same data type.
Examples
>>> df.select(array('age', 'age').alias("arr")).collect()
[Row(arr=[2, 2]), Row(arr=[5, 5])]
>>> df.select(array([df.age, df.age]).alias("arr")).collect()
[Row(arr=[2, 2]), Row(arr=[5, 5])]
pyspark.sql.functions.array_contains
pyspark.sql.functions.array_contains(col, value)
    Collection function: returns null if the array is null, true if the array contains the given value, and false otherwise.
New in version 1.5.0.
Parameters
col [Column or str] name of column containing array
value : value or column to check for in array
Examples
>>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])>>> df.select(array_contains(df.data, "a")).collect()[Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]>>> df.select(array_contains(df.data, lit("a"))).collect()[Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
pyspark.sql.functions.array_distinct
pyspark.sql.functions.array_distinct(col)
    Collection function: removes duplicate values from the array.
New in version 2.4.0.
Parameters
col [Column or str] name of column or expression
Examples
>>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data'])
>>> df.select(array_distinct(df.data)).collect()
[Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])]
pyspark.sql.functions.array_except
pyspark.sql.functions.array_except(col1, col2)
    Collection function: returns an array of the elements in col1 but not in col2, without duplicates.
New in version 2.4.0.
Parameters
col1 [Column or str] name of column containing array
col2 [Column or str] name of column containing array
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
>>> df.select(array_except(df.c1, df.c2)).collect()
[Row(array_except(c1, c2)=['b'])]
pyspark.sql.functions.array_intersect
pyspark.sql.functions.array_intersect(col1, col2)
    Collection function: returns an array of the elements in the intersection of col1 and col2, without duplicates.
New in version 2.4.0.
Parameters
col1 [Column or str] name of column containing array
col2 [Column or str] name of column containing array
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
>>> df.select(array_intersect(df.c1, df.c2)).collect()
[Row(array_intersect(c1, c2)=['a', 'c'])]
pyspark.sql.functions.array_join
pyspark.sql.functions.array_join(col, delimiter, null_replacement=None)
    Concatenates the elements of column using the delimiter. Null values are replaced with null_replacement if set, otherwise they are ignored.
New in version 2.4.0.
Examples
>>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data'])>>> df.select(array_join(df.data, ",").alias("joined")).collect()[Row(joined='a,b,c'), Row(joined='a')]>>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect()[Row(joined='a,b,c'), Row(joined='a,NULL')]
pyspark.sql.functions.array_max
pyspark.sql.functions.array_max(col)
    Collection function: returns the maximum value of the array.
New in version 2.4.0.
Parameters
col [Column or str] name of column or expression
Examples
>>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
>>> df.select(array_max(df.data).alias('max')).collect()
[Row(max=3), Row(max=10)]
pyspark.sql.functions.array_min
pyspark.sql.functions.array_min(col)
    Collection function: returns the minimum value of the array.
New in version 2.4.0.
Parameters
col [Column or str] name of column or expression
Examples
>>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
>>> df.select(array_min(df.data).alias('min')).collect()
[Row(min=1), Row(min=-1)]
pyspark.sql.functions.array_position
pyspark.sql.functions.array_position(col, value)
    Collection function: Locates the position of the first occurrence of the given value in the given array. Returns null if either of the arguments are null.
New in version 2.4.0.
Notes
The position is not zero-based, but uses a 1-based index. Returns 0 if the given value could not be found in the array.
Examples
>>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])>>> df.select(array_position(df.data, "a")).collect()[Row(array_position(data, a)=3), Row(array_position(data, a)=0)]
pyspark.sql.functions.array_remove
pyspark.sql.functions.array_remove(col, element)
    Collection function: Removes all elements equal to element from the given array.
New in version 2.4.0.
Parameters
col [Column or str] name of column containing array
element : element to be removed from the array
Examples
>>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data'])
>>> df.select(array_remove(df.data, 1)).collect()
[Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])]
pyspark.sql.functions.array_repeat
pyspark.sql.functions.array_repeat(col, count)
    Collection function: creates an array containing a column repeated count times.
New in version 2.4.0.
Examples
>>> df = spark.createDataFrame([('ab',)], ['data'])
>>> df.select(array_repeat(df.data, 3).alias('r')).collect()
[Row(r=['ab', 'ab', 'ab'])]
pyspark.sql.functions.array_sort
pyspark.sql.functions.array_sort(col)
    Collection function: sorts the input array in ascending order. The elements of the input array must be orderable. Null elements will be placed at the end of the returned array.
New in version 2.4.0.
Parameters
col [Column or str] name of column or expression
Examples
>>> df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data'])
>>> df.select(array_sort(df.data).alias('r')).collect()
[Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])]
pyspark.sql.functions.array_union
pyspark.sql.functions.array_union(col1, col2)
    Collection function: returns an array of the elements in the union of col1 and col2, without duplicates.
New in version 2.4.0.
Parameters
col1 [Column or str] name of column containing array
col2 [Column or str] name of column containing array
Examples
>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
>>> df.select(array_union(df.c1, df.c2)).collect()
[Row(array_union(c1, c2)=['b', 'a', 'c', 'd', 'f'])]
pyspark.sql.functions.arrays_overlap
pyspark.sql.functions.arrays_overlap(a1, a2)
    Collection function: returns true if the arrays contain any common non-null element; if not, returns null if both the arrays are non-empty and any of them contains a null element; returns false otherwise.
New in version 2.4.0.
Examples
>>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], [→˓'x', 'y'])>>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect()[Row(overlap=True), Row(overlap=False)]
pyspark.sql.functions.arrays_zip
pyspark.sql.functions.arrays_zip(*cols)
    Collection function: Returns a merged array of structs in which the N-th struct contains all N-th values of input arrays.
New in version 2.4.0.
Parameters
cols [Column or str] columns of arrays to be merged.
Examples
>>> from pyspark.sql.functions import arrays_zip
>>> df = spark.createDataFrame([(([1, 2, 3], [2, 3, 4]))], ['vals1', 'vals2'])
>>> df.select(arrays_zip(df.vals1, df.vals2).alias('zipped')).collect()
[Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])]
pyspark.sql.functions.asc
pyspark.sql.functions.asc(col)
    Returns a sort expression based on the ascending order of the given column name.
New in version 1.3.
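A minimal sketch, assuming the small name/age DataFrame bound to df that the surrounding examples use:

>>> df.sort(asc("age")).collect()
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]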
pyspark.sql.functions.asc_nulls_first
pyspark.sql.functions.asc_nulls_first(col)
    Returns a sort expression based on the ascending order of the given column name, and null values appear before non-null values.
New in version 2.4.
pyspark.sql.functions.asc_nulls_last
pyspark.sql.functions.asc_nulls_last(col)
    Returns a sort expression based on the ascending order of the given column name, and null values appear after non-null values.
New in version 2.4.
pyspark.sql.functions.ascii
pyspark.sql.functions.ascii(col)
    Computes the numeric value of the first character of the string column.
New in version 1.5.
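A minimal sketch ('S' has code point 83), assuming an active SparkSession bound to spark:

>>> spark.createDataFrame([('Spark',)], ['s']).select(ascii('s').alias('a')).collect()
[Row(a=83)]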
pyspark.sql.functions.asin
pyspark.sql.functions.asin(col)
    New in version 1.3.0.
    Returns
        Column: inverse sine of col, as if computed by java.lang.Math.asin()
pyspark.sql.functions.asinh
pyspark.sql.functions.asinh(col)
    Computes inverse hyperbolic sine of the input column.
New in version 3.1.0.
Returns
Column
pyspark.sql.functions.assert_true
pyspark.sql.functions.assert_true(col, errMsg=None)
    Returns null if the input column is true; throws an exception with the provided error message otherwise.
New in version 3.1.0.
Examples
>>> df = spark.createDataFrame([(0, 1)], ['a', 'b'])
>>> df.select(assert_true(df.a < df.b).alias('r')).collect()
[Row(r=None)]
>>> df = spark.createDataFrame([(0, 1)], ['a', 'b'])
>>> df.select(assert_true(df.a < df.b, df.a).alias('r')).collect()
[Row(r=None)]
>>> df = spark.createDataFrame([(0, 1)], ['a', 'b'])
>>> df.select(assert_true(df.a < df.b, 'error').alias('r')).collect()
[Row(r=None)]
pyspark.sql.functions.atan
pyspark.sql.functions.atan(col)
    New in version 1.4.0.
    Returns
        Column: inverse tangent of col, as if computed by java.lang.Math.atan()
pyspark.sql.functions.atanh
pyspark.sql.functions.atanh(col)
    Computes inverse hyperbolic tangent of the input column.
New in version 3.1.0.
Returns
Column
pyspark.sql.functions.atan2
pyspark.sql.functions.atan2(col1, col2)
    New in version 1.4.0.
Parameters
col1 [str, Column or float] coordinate on y-axis
col2 [str, Column or float] coordinate on x-axis
Returns
Column: the theta component of the point (r, theta) in polar coordinates that corresponds to the point (x, y) in Cartesian coordinates, as if computed by java.lang.Math.atan2()
pyspark.sql.functions.avg
pyspark.sql.functions.avg(col)
    Aggregate function: returns the average of the values in a group.
New in version 1.3.
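A minimal sketch, assuming the name/age DataFrame bound to df used elsewhere on this page:

>>> df.agg(avg(df.age).alias('avg_age')).collect()
[Row(avg_age=3.5)]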
pyspark.sql.functions.base64
pyspark.sql.functions.base64(col)
    Computes the BASE64 encoding of a binary column and returns it as a string column.
New in version 1.5.
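A minimal sketch (the string input is implicitly cast to binary), assuming an active SparkSession bound to spark:

>>> spark.createDataFrame([('Spark',)], ['s']).select(base64('s').alias('b')).collect()
[Row(b='U3Bhcms=')]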
pyspark.sql.functions.bin
pyspark.sql.functions.bin(col)
    Returns the string representation of the binary value of the given column.
New in version 1.5.0.
Examples
>>> df.select(bin(df.age).alias('c')).collect()
[Row(c='10'), Row(c='101')]
pyspark.sql.functions.bitwiseNOT
pyspark.sql.functions.bitwiseNOT(col)
    Computes bitwise not.
New in version 1.4.
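A minimal sketch (~1 is -2 in two's complement), assuming an active SparkSession bound to spark:

>>> spark.createDataFrame([(1,)], ['v']).select(bitwiseNOT('v').alias('r')).collect()
[Row(r=-2)]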
pyspark.sql.functions.broadcast
pyspark.sql.functions.broadcast(df)
    Marks a DataFrame as small enough for use in broadcast joins.
New in version 1.6.
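A minimal sketch of a broadcast join, assuming the name/age DataFrame bound to df used elsewhere on this page; df_small is illustrative only:

>>> df_small = spark.createDataFrame([(2, 'US')], ['age', 'country'])
>>> df.join(broadcast(df_small), 'age').collect()
[Row(age=2, name='Alice', country='US')]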
pyspark.sql.functions.bround
pyspark.sql.functions.bround(col, scale=0)
    Round the given value to scale decimal places using HALF_EVEN rounding mode if scale >= 0 or at integral part when scale < 0.
New in version 2.0.0.
Examples
>>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect()
[Row(r=2.0)]
pyspark.sql.functions.bucket
pyspark.sql.functions.bucket(numBuckets, col)
    Partition transform function: A transform for any type that partitions by a hash of the input column.
New in version 3.1.0.
Notes
This function can be used only in combination with the partitionedBy() method of DataFrameWriterV2.
Examples
>>> df.writeTo("catalog.db.table").partitionedBy(... bucket(42, "ts")... ).createOrReplace()
pyspark.sql.functions.cbrt
pyspark.sql.functions.cbrt(col)
    Computes the cube-root of the given value.
New in version 1.4.
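A minimal sketch, assuming an active SparkSession bound to spark:

>>> spark.createDataFrame([(27.0,)], ['v']).select(cbrt('v').alias('r')).collect()
[Row(r=3.0)]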
pyspark.sql.functions.ceil
pyspark.sql.functions.ceil(col)
Computes the ceiling of the given value.
New in version 1.4.
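An illustrative sketch, not from the upstream reference, assuming the usual spark session; the result is an integral column:

>>> spark.createDataFrame([(1.2,)], ['a']).select(ceil('a').alias('r')).collect()
[Row(r=2)]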
pyspark.sql.functions.coalesce
pyspark.sql.functions.coalesce(*cols)
Returns the first column that is not null.
New in version 1.4.0.
Examples
>>> cDf = spark.createDataFrame([(None, None), (1, None), (None, 2)], ("a", "b"))
>>> cDf.show()
+----+----+
|   a|   b|
+----+----+
|null|null|
|   1|null|
|null|   2|
+----+----+
>>> cDf.select(coalesce(cDf["a"], cDf["b"])).show()
+--------------+
|coalesce(a, b)|
+--------------+
|          null|
|             1|
|             2|
+--------------+
>>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show()
+----+----+----------------+
|   a|   b|coalesce(a, 0.0)|
+----+----+----------------+
|null|null|             0.0|
|   1|null|             1.0|
|null|   2|             0.0|
+----+----+----------------+
pyspark.sql.functions.col
pyspark.sql.functions.col(col)
Returns a Column based on the given column name.
New in version 1.3.
pyspark.sql.functions.collect_list
pyspark.sql.functions.collect_list(col)
Aggregate function: returns a list of objects with duplicates.
New in version 1.6.0.
Notes
The function is non-deterministic because the order of collected results depends on the order of the rows which may be non-deterministic after a shuffle.
Examples
>>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',))
>>> df2.agg(collect_list('age')).collect()
[Row(collect_list(age)=[2, 5, 5])]
pyspark.sql.functions.collect_set
pyspark.sql.functions.collect_set(col)
Aggregate function: returns a set of objects with duplicate elements eliminated.
New in version 1.6.0.
Notes
The function is non-deterministic because the order of collected results depends on the order of the rows which may be non-deterministic after a shuffle.
Examples
>>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',))
>>> df2.agg(collect_set('age')).collect()
[Row(collect_set(age)=[5, 2])]
pyspark.sql.functions.column
pyspark.sql.functions.column(col)
Returns a Column based on the given column name.
New in version 1.3.
pyspark.sql.functions.concat
pyspark.sql.functions.concat(*cols)
Concatenates multiple input columns together into a single column. The function works with strings, binary and compatible array columns.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('abcd', '123')], ['s', 'd'])
>>> df.select(concat(df.s, df.d).alias('s')).collect()
[Row(s='abcd123')]
>>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
>>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect()
[Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
pyspark.sql.functions.concat_ws
pyspark.sql.functions.concat_ws(sep, *cols)
Concatenates multiple input string columns together into a single string column, using the given separator.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('abcd', '123')], ['s', 'd'])
>>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect()
[Row(s='abcd-123')]
pyspark.sql.functions.conv
pyspark.sql.functions.conv(col, fromBase, toBase)
Convert a number in a string column from one base to another.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([("010101",)], ['n'])
>>> df.select(conv(df.n, 2, 16).alias('hex')).collect()
[Row(hex='15')]
pyspark.sql.functions.corr
pyspark.sql.functions.corr(col1, col2)
Returns a new Column for the Pearson Correlation Coefficient for col1 and col2.
New in version 1.6.0.
Examples
>>> a = range(20)
>>> b = [2 * x for x in range(20)]
>>> df = spark.createDataFrame(zip(a, b), ["a", "b"])
>>> df.agg(corr("a", "b").alias('c')).collect()
[Row(c=1.0)]
pyspark.sql.functions.cos
pyspark.sql.functions.cos(col)
New in version 1.4.0.
Parameters
col [Column or str] angle in radians
Returns
Column cosine of the angle, as if computed by java.lang.Math.cos().
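A minimal sketch, not part of the upstream reference, assuming the usual spark session; cos(0) is 1:

>>> spark.createDataFrame([(0.0,)], ['x']).select(cos('x').alias('r')).collect()
[Row(r=1.0)]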
pyspark.sql.functions.cosh
pyspark.sql.functions.cosh(col)
New in version 1.4.0.
Parameters
col [Column or str] hyperbolic angle
Returns
Column hyperbolic cosine of the angle, as if computed by java.lang.Math.cosh()
pyspark.sql.functions.count
pyspark.sql.functions.count(col)
Aggregate function: returns the number of items in a group.
New in version 1.3.
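An illustrative sketch, not from the upstream reference, assuming the usual spark session and a DDL schema string so a null can be included; nulls are not counted:

>>> df = spark.createDataFrame([(1,), (2,), (None,)], 'a int')
>>> df.agg(count('a').alias('c')).collect()
[Row(c=2)]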
pyspark.sql.functions.countDistinct
pyspark.sql.functions.countDistinct(col, *cols)
Returns a new Column for distinct count of col or cols.
New in version 1.3.0.
Examples
>>> df.agg(countDistinct(df.age, df.name).alias('c')).collect()
[Row(c=2)]
>>> df.agg(countDistinct("age", "name").alias('c')).collect()
[Row(c=2)]
pyspark.sql.functions.covar_pop
pyspark.sql.functions.covar_pop(col1, col2)
Returns a new Column for the population covariance of col1 and col2.
New in version 2.0.0.
Examples
>>> a = [1] * 10
>>> b = [1] * 10
>>> df = spark.createDataFrame(zip(a, b), ["a", "b"])
>>> df.agg(covar_pop("a", "b").alias('c')).collect()
[Row(c=0.0)]
pyspark.sql.functions.covar_samp
pyspark.sql.functions.covar_samp(col1, col2)
Returns a new Column for the sample covariance of col1 and col2.
New in version 2.0.0.
Examples
>>> a = [1] * 10
>>> b = [1] * 10
>>> df = spark.createDataFrame(zip(a, b), ["a", "b"])
>>> df.agg(covar_samp("a", "b").alias('c')).collect()
[Row(c=0.0)]
pyspark.sql.functions.crc32
pyspark.sql.functions.crc32(col)
Calculates the cyclic redundancy check value (CRC32) of a binary column and returns the value as a bigint.
New in version 1.5.0.
Examples
>>> spark.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect()
[Row(crc32=2743272264)]
pyspark.sql.functions.create_map
pyspark.sql.functions.create_map(*cols)
Creates a new map column.
New in version 2.0.0.
Parameters
cols [Column or str] column names or Columns that are grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...).
Examples
>>> df.select(create_map('name', 'age').alias("map")).collect()
[Row(map={'Alice': 2}), Row(map={'Bob': 5})]
>>> df.select(create_map([df.name, df.age]).alias("map")).collect()
[Row(map={'Alice': 2}), Row(map={'Bob': 5})]
pyspark.sql.functions.cume_dist
pyspark.sql.functions.cume_dist()
Window function: returns the cumulative distribution of values within a window partition, i.e. the fraction of rows that are below the current row.
New in version 1.6.
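A hedged sketch, not from the upstream reference, assuming the usual spark session and Window imported from pyspark.sql; each row reports the fraction of rows at or below its value:

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([(1,), (2,), (3,), (3,)], ['v'])
>>> df.select('v', cume_dist().over(Window.orderBy('v')).alias('cd')).collect()
[Row(v=1, cd=0.25), Row(v=2, cd=0.5), Row(v=3, cd=1.0), Row(v=3, cd=1.0)]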
pyspark.sql.functions.current_date
pyspark.sql.functions.current_date()
Returns the current date at the start of query evaluation as a DateType column. All calls of current_date within the same query return the same value.
New in version 1.5.
pyspark.sql.functions.current_timestamp
pyspark.sql.functions.current_timestamp()
Returns the current timestamp at the start of query evaluation as a TimestampType column. All calls of current_timestamp within the same query return the same value.
pyspark.sql.functions.date_add
pyspark.sql.functions.date_add(start, days)
Returns the date that is days days after start.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(date_add(df.dt, 1).alias('next_date')).collect()
[Row(next_date=datetime.date(2015, 4, 9))]
pyspark.sql.functions.date_format
pyspark.sql.functions.date_format(date, format)
Converts a date/timestamp/string to a value of string in the format specified by the date format given by the second argument.
A pattern could be for instance dd.MM.yyyy and could return a string like ‘18.03.1993’. All pattern letters of datetime patterns can be used.
New in version 1.5.0.
Notes
Whenever possible, use specialized functions like year.
Examples
>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(date_format('dt', 'MM/dd/yyy').alias('date')).collect()
[Row(date='04/08/2015')]
pyspark.sql.functions.date_sub
pyspark.sql.functions.date_sub(start, days)
Returns the date that is days days before start.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(date_sub(df.dt, 1).alias('prev_date')).collect()
[Row(prev_date=datetime.date(2015, 4, 7))]
pyspark.sql.functions.date_trunc
pyspark.sql.functions.date_trunc(format, timestamp)
Returns timestamp truncated to the unit specified by the format.
New in version 2.3.0.
Parameters
format [str] ‘year’, ‘yyyy’, ‘yy’, ‘month’, ‘mon’, ‘mm’, ‘day’, ‘dd’, ‘hour’, ‘minute’, ‘second’,‘week’, ‘quarter’
timestamp [Column or str]
Examples
>>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t'])
>>> df.select(date_trunc('year', df.t).alias('year')).collect()
[Row(year=datetime.datetime(1997, 1, 1, 0, 0))]
>>> df.select(date_trunc('mon', df.t).alias('month')).collect()
[Row(month=datetime.datetime(1997, 2, 1, 0, 0))]
pyspark.sql.functions.datediff
pyspark.sql.functions.datediff(end, start)
Returns the number of days from start to end.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08', '2015-05-10')], ['d1', 'd2'])
>>> df.select(datediff(df.d2, df.d1).alias('diff')).collect()
[Row(diff=32)]
pyspark.sql.functions.dayofmonth
pyspark.sql.functions.dayofmonth(col)
Extract the day of the month of a given date as integer.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(dayofmonth('dt').alias('day')).collect()
[Row(day=8)]
pyspark.sql.functions.dayofweek
pyspark.sql.functions.dayofweek(col)
Extract the day of the week of a given date as integer.
New in version 2.3.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(dayofweek('dt').alias('day')).collect()
[Row(day=4)]
pyspark.sql.functions.dayofyear
pyspark.sql.functions.dayofyear(col)
Extract the day of the year of a given date as integer.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(dayofyear('dt').alias('day')).collect()
[Row(day=98)]
pyspark.sql.functions.days
pyspark.sql.functions.days(col)
Partition transform function: A transform for timestamps and dates to partition data into days.
New in version 3.1.0.
Notes
This function can be used only in combination with partitionedBy() method of the DataFrameWriterV2.
Examples
>>> df.writeTo("catalog.db.table").partitionedBy(
...     days("ts")
... ).createOrReplace()
pyspark.sql.functions.decode
pyspark.sql.functions.decode(col, charset)
Computes the first argument into a string from a binary using the provided character set (one of ‘US-ASCII’, ‘ISO-8859-1’, ‘UTF-8’, ‘UTF-16BE’, ‘UTF-16LE’, ‘UTF-16’).
New in version 1.5.
pyspark.sql.functions.degrees
pyspark.sql.functions.degrees(col)
Converts an angle measured in radians to an approximately equivalent angle measured in degrees.
New in version 2.1.0.
Parameters
col [Column or str] angle in radians
Returns
Column angle in degrees, as if computed by java.lang.Math.toDegrees()
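A minimal sketch, not part of the upstream reference, assuming the usual spark session; pi radians is 180 degrees:

>>> import math
>>> spark.createDataFrame([(math.pi,)], ['rad']).select(degrees('rad').alias('deg')).collect()
[Row(deg=180.0)]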
pyspark.sql.functions.dense_rank
pyspark.sql.functions.dense_rank()
Window function: returns the rank of rows within a window partition, without any gaps.
The difference between rank and dense_rank is that dense_rank leaves no gaps in the ranking sequence when there are ties. That is, if you were ranking a competition using dense_rank and had three people tie for second place, you would say that all three were in second place and that the next person came in third. Rank, by contrast, would give sequential numbers, so the person that came in third place (after the ties) would register as coming in fifth.
This is equivalent to the DENSE_RANK function in SQL.
New in version 1.6.
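An illustrative sketch, not from the upstream reference, assuming spark plus Window, dense_rank and rank imported; with a tie at v=1, dense_rank stays gapless while rank skips:

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([(1,), (1,), (2,)], ['v'])
>>> w = Window.orderBy('v')
>>> df.select('v', dense_rank().over(w).alias('dr'), rank().over(w).alias('r')).collect()
[Row(v=1, dr=1, r=1), Row(v=1, dr=1, r=1), Row(v=2, dr=2, r=3)]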
pyspark.sql.functions.desc
pyspark.sql.functions.desc(col)
Returns a sort expression based on the descending order of the given column name.
New in version 1.3.
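A minimal sketch, not part of the upstream reference, assuming the usual spark session:

>>> df = spark.createDataFrame([(1,), (3,), (2,)], ['v'])
>>> df.orderBy(desc('v')).collect()
[Row(v=3), Row(v=2), Row(v=1)]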
pyspark.sql.functions.desc_nulls_first
pyspark.sql.functions.desc_nulls_first(col)
Returns a sort expression based on the descending order of the given column name, and null values appear before non-null values.
New in version 2.4.
pyspark.sql.functions.desc_nulls_last
pyspark.sql.functions.desc_nulls_last(col)
Returns a sort expression based on the descending order of the given column name, and null values appear after non-null values.
New in version 2.4.
pyspark.sql.functions.element_at
pyspark.sql.functions.element_at(col, extraction)
Collection function: Returns element of array at given index in extraction if col is array. Returns value for the given key in extraction if col is map.
New in version 2.4.0.
Parameters
col [Column or str] name of column containing array or map
extraction : index to check for in array or key to check for in map
Notes
The position is not zero-based but 1-based.
Examples
>>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
>>> df.select(element_at(df.data, 1)).collect()
[Row(element_at(data, 1)='a'), Row(element_at(data, 1)=None)]
>>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data'])
>>> df.select(element_at(df.data, lit("a"))).collect()
[Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)]
pyspark.sql.functions.encode
pyspark.sql.functions.encode(col, charset)
Computes the first argument into a binary from a string using the provided character set (one of ‘US-ASCII’, ‘ISO-8859-1’, ‘UTF-8’, ‘UTF-16BE’, ‘UTF-16LE’, ‘UTF-16’).
New in version 1.5.
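An illustrative sketch, not from the upstream reference, assuming the usual spark session; the binary result surfaces as a Python bytearray:

>>> spark.createDataFrame([('Spark',)], ['a']).select(encode('a', 'UTF-8').alias('b')).collect()
[Row(b=bytearray(b'Spark'))]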
pyspark.sql.functions.exists
pyspark.sql.functions.exists(col, f)
Returns whether a predicate holds for one or more elements in the array.
New in version 3.1.0.
Parameters
col [Column or str] name of column or expression
f [function] (x: Column) -> Column: ... returning the Boolean expression. Can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).
Returns
pyspark.sql.Column
Examples
>>> df = spark.createDataFrame([(1, [1, 2, 3, 4]), (2, [3, -1, 0])], ("key", "values"))
>>> df.select(exists("values", lambda x: x < 0).alias("any_negative")).show()
+------------+
|any_negative|
+------------+
|       false|
|        true|
+------------+
pyspark.sql.functions.exp
pyspark.sql.functions.exp(col)
Computes the exponential of the given value.
New in version 1.4.
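A minimal sketch, not part of the upstream reference, assuming the usual spark session; e to the power 0 is 1:

>>> spark.createDataFrame([(0.0,)], ['x']).select(exp('x').alias('r')).collect()
[Row(r=1.0)]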
pyspark.sql.functions.explode
pyspark.sql.functions.explode(col)
Returns a new row for each element in the given array or map. Uses the default column name col for elements in the array and key and value for elements in the map unless specified otherwise.
New in version 1.4.0.
Examples
>>> from pyspark.sql import Row
>>> eDF = spark.createDataFrame([Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})])
>>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
[Row(anInt=1), Row(anInt=2), Row(anInt=3)]
>>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
+---+-----+
|key|value|
+---+-----+
|  a|    b|
+---+-----+
pyspark.sql.functions.explode_outer
pyspark.sql.functions.explode_outer(col)
Returns a new row for each element in the given array or map. Unlike explode, if the array/map is null or empty then null is produced. Uses the default column name col for elements in the array and key and value for elements in the map unless specified otherwise.
New in version 2.3.0.
Examples
>>> df = spark.createDataFrame(
...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
...     ("id", "an_array", "a_map")
... )
>>> df.select("id", "an_array", explode_outer("a_map")).show()
+---+----------+----+-----+
| id|  an_array| key|value|
+---+----------+----+-----+
|  1|[foo, bar]|   x|  1.0|
|  2|        []|null| null|
|  3|      null|null| null|
+---+----------+----+-----+
>>> df.select("id", "a_map", explode_outer("an_array")).show()
+---+----------+----+
| id|     a_map| col|
+---+----------+----+
|  1|{x -> 1.0}| foo|
|  1|{x -> 1.0}| bar|
|  2|        {}|null|
|  3|      null|null|
+---+----------+----+
pyspark.sql.functions.expm1
pyspark.sql.functions.expm1(col)
Computes the exponential of the given value minus one.
New in version 1.4.
pyspark.sql.functions.expr
pyspark.sql.functions.expr(str)
Parses the expression string into the column that it represents.
New in version 1.5.0.
Examples
>>> df.select(expr("length(name)")).collect()
[Row(length(name)=5), Row(length(name)=3)]
pyspark.sql.functions.factorial
pyspark.sql.functions.factorial(col)
Computes the factorial of the given value.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([(5,)], ['n'])
>>> df.select(factorial(df.n).alias('f')).collect()
[Row(f=120)]
pyspark.sql.functions.filter
pyspark.sql.functions.filter(col, f)
Returns an array of elements for which a predicate holds in a given array.
New in version 3.1.0.
Parameters
col [Column or str] name of column or expression
f [function] A function that returns the Boolean expression. Can take one of the following forms:
• Unary (x: Column) -> Column: ...
• Binary (x: Column, i: Column) -> Column..., where the second argument is a 0-based index of the element.
and can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).
Returns
pyspark.sql.Column
Examples
>>> df = spark.createDataFrame(
...     [(1, ["2018-09-20", "2019-02-03", "2019-07-01", "2020-06-01"])],
...     ("key", "values")
... )
>>> def after_second_quarter(x):
...     return month(to_date(x)) > 6
>>> df.select(
...     filter("values", after_second_quarter).alias("after_second_quarter")
... ).show(truncate=False)
+------------------------+
|after_second_quarter    |
+------------------------+
|[2018-09-20, 2019-07-01]|
+------------------------+
pyspark.sql.functions.first
pyspark.sql.functions.first(col, ignorenulls=False)
Aggregate function: returns the first value in a group.
The function by default returns the first values it sees. It will return the first non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
New in version 1.3.0.
Notes
The function is non-deterministic because its result depends on the order of the rows which may be non-deterministic after a shuffle.
pyspark.sql.functions.flatten
pyspark.sql.functions.flatten(col)
Collection function: creates a single array from an array of arrays. If a structure of nested arrays is deeper than two levels, only one level of nesting is removed.
New in version 2.4.0.
Parameters
col [Column or str] name of column or expression
Examples
>>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data'])
>>> df.select(flatten(df.data).alias('r')).collect()
[Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)]
pyspark.sql.functions.floor
pyspark.sql.functions.floor(col)
Computes the floor of the given value.
New in version 1.4.
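A minimal sketch, not part of the upstream reference, assuming the usual spark session:

>>> spark.createDataFrame([(2.7,)], ['a']).select(floor('a').alias('r')).collect()
[Row(r=2)]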
pyspark.sql.functions.forall
pyspark.sql.functions.forall(col, f)
Returns whether a predicate holds for every element in the array.
New in version 3.1.0.
Parameters
col [Column or str] name of column or expression
f [function] (x: Column) -> Column: ... returning the Boolean expression. Can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).
Returns
pyspark.sql.Column
Examples
>>> df = spark.createDataFrame(
...     [(1, ["bar"]), (2, ["foo", "bar"]), (3, ["foobar", "foo"])],
...     ("key", "values")
... )
>>> df.select(forall("values", lambda x: x.rlike("foo")).alias("all_foo")).show()
+-------+
|all_foo|
+-------+
|  false|
|  false|
|   true|
+-------+
pyspark.sql.functions.format_number
pyspark.sql.functions.format_number(col, d)
Formats the number X to a format like '#,###,###.##', rounded to d decimal places with HALF_EVEN round mode, and returns the result as a string.
New in version 1.5.0.
Parameters
col [Column or str] the column name of the numeric value to be formatted
d [int] the N decimal places
Examples

>>> spark.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
[Row(v='5.0000')]
pyspark.sql.functions.format_string
pyspark.sql.functions.format_string(format, *cols)
Formats the arguments in printf-style and returns the result as a string column.
New in version 1.5.0.
Parameters
format [str] string that can contain embedded format tags and used as result column’s value
cols [Column or str] column names or Columns to be used in formatting
Examples
>>> df = spark.createDataFrame([(5, "hello")], ['a', 'b'])
>>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect()
[Row(v='5 hello')]
pyspark.sql.functions.from_csv
pyspark.sql.functions.from_csv(col, schema, options=None)
Parses a column containing a CSV string to a row with the specified schema. Returns null, in the case of an unparseable string.
New in version 3.0.0.
Parameters
col [Column or str] string column in CSV format
schema [Column or str] a string with schema in DDL format to use when parsing the CSV column.
options [dict, optional] options to control parsing. Accepts the same options as the CSV datasource.
Examples
>>> data = [("1,2,3",)]
>>> df = spark.createDataFrame(data, ("value",))
>>> df.select(from_csv(df.value, "a INT, b INT, c INT").alias("csv")).collect()
[Row(csv=Row(a=1, b=2, c=3))]
>>> value = data[0][0]
>>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect()
[Row(csv=Row(_c0=1, _c1=2, _c2=3))]
>>> data = [(" abc",)]
>>> df = spark.createDataFrame(data, ("value",))
>>> options = {'ignoreLeadingWhiteSpace': True}
>>> df.select(from_csv(df.value, "s string", options).alias("csv")).collect()
[Row(csv=Row(s='abc'))]
pyspark.sql.functions.from_json
pyspark.sql.functions.from_json(col, schema, options=None)
Parses a column containing a JSON string into a MapType with StringType as keys type, StructType or ArrayType with the specified schema. Returns null, in the case of an unparseable string.
New in version 2.1.0.
Parameters
col [Column or str] string column in json format
schema [DataType or str] a StructType or ArrayType of StructType to use when parsing thejson column.
Changed in version 2.3: the DDL-formatted string is also supported for schema.
options [dict, optional] options to control parsing. Accepts the same options as the JSON datasource.
Examples
>>> from pyspark.sql.types import *
>>> data = [(1, '''{"a": 1}''')]
>>> schema = StructType([StructField("a", IntegerType())])
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=Row(a=1))]
>>> df.select(from_json(df.value, "a INT").alias("json")).collect()
[Row(json=Row(a=1))]
>>> df.select(from_json(df.value, "MAP<STRING,INT>").alias("json")).collect()
[Row(json={'a': 1})]
>>> data = [(1, '''[{"a": 1}]''')]
>>> schema = ArrayType(StructType([StructField("a", IntegerType())]))
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=[Row(a=1)])]
>>> schema = schema_of_json(lit('''{"a": 0}'''))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=Row(a=None))]
>>> data = [(1, '''[1, 2, 3]''')]
>>> schema = ArrayType(IntegerType())
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=[1, 2, 3])]
pyspark.sql.functions.from_unixtime
pyspark.sql.functions.from_unixtime(timestamp, format='yyyy-MM-dd HH:mm:ss')
Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string representing the timestamp of that moment in the current system time zone in the given format.
New in version 1.5.0.
Examples
>>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
>>> time_df = spark.createDataFrame([(1428476400,)], ['unix_time'])
>>> time_df.select(from_unixtime('unix_time').alias('ts')).collect()
[Row(ts='2015-04-08 00:00:00')]
>>> spark.conf.unset("spark.sql.session.timeZone")
pyspark.sql.functions.from_utc_timestamp
pyspark.sql.functions.from_utc_timestamp(timestamp, tz)
This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in UTC, and renders that timestamp as a timestamp in the given time zone.
However, a timestamp in Spark represents the number of microseconds from the Unix epoch, which is not timezone-agnostic. So in Spark this function just shifts the timestamp value from the UTC time zone to the given time zone.
This function may return a confusing result if the input is a string with a time zone, e.g. ‘2018-03-13T06:18:23+00:00’. The reason is that Spark first casts the string to a timestamp according to the time zone in the string, and finally displays the result by converting the timestamp to a string according to the session local time zone.
New in version 1.5.0.
Parameters
timestamp [Column or str] the column that contains timestamps
tz [Column or str] A string detailing the time zone ID that the input should be adjusted to. It should be in the format of either region-based zone IDs or zone offsets. Region IDs must have the form ‘area/city’, such as ‘America/Los_Angeles’. Zone offsets must be in the format ‘(+|-)HH:mm’, for example ‘-08:00’ or ‘+01:00’. Also ‘UTC’ and ‘Z’ are supported as aliases of ‘+00:00’. Other short names are not recommended to use because they can be ambiguous.
Changed in version 2.4: tz can take a Column containing timezone ID strings.
Examples
>>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz'])
>>> df.select(from_utc_timestamp(df.ts, "PST").alias('local_time')).collect()
[Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))]
>>> df.select(from_utc_timestamp(df.ts, df.tz).alias('local_time')).collect()
[Row(local_time=datetime.datetime(1997, 2, 28, 19, 30))]
pyspark.sql.functions.get_json_object
pyspark.sql.functions.get_json_object(col, path)
Extracts json object from a json string based on json path specified, and returns json string of the extracted json object. It will return null if the input json string is invalid.
New in version 1.6.0.
Parameters
col [Column or str] string column in json format
path [str] path to the json object to extract
Examples
>>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')]
>>> df = spark.createDataFrame(data, ("key", "jstring"))
>>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \
...     get_json_object(df.jstring, '$.f2').alias("c1")).collect()
[Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)]
pyspark.sql.functions.greatest
pyspark.sql.functions.greatest(*cols)
Returns the greatest value of the list of column names, skipping null values. This function takes at least 2 parameters. It will return null iff all parameters are null.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
>>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect()
[Row(greatest=4)]
pyspark.sql.functions.grouping
pyspark.sql.functions.grouping(col)
Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated or not, returns 1 for aggregated or 0 for not aggregated in the result set.
New in version 2.0.0.
Examples
>>> df.cube("name").agg(grouping("name"), sum("age")).orderBy("name").show()
+-----+--------------+--------+
| name|grouping(name)|sum(age)|
+-----+--------------+--------+
| null|             1|       7|
|Alice|             0|       2|
|  Bob|             0|       5|
+-----+--------------+--------+
pyspark.sql.functions.grouping_id
pyspark.sql.functions.grouping_id(*cols)
Aggregate function: returns the level of grouping, equals to
(grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
New in version 2.0.0.
Notes
The list of columns should match the grouping columns exactly, or be empty (meaning all the grouping columns).
Examples
>>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show()
+-----+-------------+--------+
| name|grouping_id()|sum(age)|
+-----+-------------+--------+
| null|            1|       7|
|Alice|            0|       2|
|  Bob|            0|       5|
+-----+-------------+--------+
pyspark.sql.functions.hash
pyspark.sql.functions.hash(*cols)
Calculates the hash code of given columns, and returns the result as an int column.
New in version 2.0.0.
Examples
>>> spark.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect()
[Row(hash=-757602832)]
pyspark.sql.functions.hex
pyspark.sql.functions.hex(col)
Computes hex value of the given column, which could be pyspark.sql.types.StringType, pyspark.sql.types.BinaryType, pyspark.sql.types.IntegerType or pyspark.sql.types.LongType.
New in version 1.5.0.
Examples
>>> spark.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
[Row(hex(a)='414243', hex(b)='3')]
pyspark.sql.functions.hour
pyspark.sql.functions.hour(col)
Extract the hours of a given date as integer.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts'])
>>> df.select(hour('ts').alias('hour')).collect()
[Row(hour=13)]
pyspark.sql.functions.hours
pyspark.sql.functions.hours(col)
Partition transform function: A transform for timestamps to partition data into hours.
New in version 3.1.0.
Notes
This function can be used only in combination with partitionedBy() method of the DataFrameWriterV2.
Examples
>>> df.writeTo("catalog.db.table").partitionedBy(
...     hours("ts")
... ).createOrReplace()
pyspark.sql.functions.hypot
pyspark.sql.functions.hypot(col1, col2)
Computes sqrt(a^2 + b^2) without intermediate overflow or underflow.
New in version 1.4.
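A minimal sketch, not part of the upstream reference, assuming the usual spark session; the 3-4-5 right triangle:

>>> df = spark.createDataFrame([(3.0, 4.0)], ['a', 'b'])
>>> df.select(hypot('a', 'b').alias('r')).collect()
[Row(r=5.0)]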
pyspark.sql.functions.initcap
pyspark.sql.functions.initcap(col)
Translate the first letter of each word to upper case in the sentence.
New in version 1.5.0.
Examples
>>> spark.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect()
[Row(v='Ab Cd')]
pyspark.sql.functions.input_file_name
pyspark.sql.functions.input_file_name()
Creates a string column for the file name of the current Spark task.
New in version 1.6.
pyspark.sql.functions.instr
pyspark.sql.functions.instr(str, substr)
Locate the position of the first occurrence of substr column in the given string. Returns null if either of the arguments are null.
New in version 1.5.0.
Notes
The position is not zero-based but 1-based. Returns 0 if substr could not be found in str.
Examples

>>> df = spark.createDataFrame([('abcd',)], ['s',])
>>> df.select(instr(df.s, 'b').alias('s')).collect()
[Row(s=2)]
pyspark.sql.functions.isnan
pyspark.sql.functions.isnan(col)
An expression that returns true iff the column is NaN.
New in version 1.6.0.
Examples
>>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b"))
>>> df.select(isnan("a").alias("r1"), isnan(df.a).alias("r2")).collect()
[Row(r1=False, r2=False), Row(r1=True, r2=True)]
pyspark.sql.functions.isnull
pyspark.sql.functions.isnull(col)
An expression that returns true iff the column is null.
New in version 1.6.0.
Examples
>>> df = spark.createDataFrame([(1, None), (None, 2)], ("a", "b"))
>>> df.select(isnull("a").alias("r1"), isnull(df.a).alias("r2")).collect()
[Row(r1=False, r2=False), Row(r1=True, r2=True)]
pyspark.sql.functions.json_tuple
pyspark.sql.functions.json_tuple(col, *fields)
Creates a new row for a json column according to the given field names.
New in version 1.6.0.
Parameters
col [Column or str] string column in json format
fields [str] fields to extract
Examples
>>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')]
>>> df = spark.createDataFrame(data, ("key", "jstring"))
>>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect()
[Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)]
pyspark.sql.functions.kurtosis
pyspark.sql.functions.kurtosis(col)
Aggregate function: returns the kurtosis of the values in a group.
New in version 1.6.
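An illustrative sketch, not from the upstream reference, assuming the usual spark session; the excess kurtosis of the evenly spaced sample 1, 2, 3 is -1.5:

>>> df = spark.createDataFrame([(1,), (2,), (3,)], ['v'])
>>> df.agg(kurtosis('v').alias('k')).collect()
[Row(k=-1.5)]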
pyspark.sql.functions.lag
pyspark.sql.functions.lag(col, offset=1, default=None)
Window function: returns the value that is offset rows before the current row, and defaultValue if there is less than offset rows before the current row. For example, an offset of one will return the previous row at any given point in the window partition.
This is equivalent to the LAG function in SQL.
New in version 1.4.0.
Parameters
col [Column or str] name of column or expression
offset [int, optional] number of rows to extend
default [optional] default value
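A hedged sketch, not from the upstream reference, assuming spark and Window imported from pyspark.sql; the first row of each partition has no previous row, so it gets null:

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([('a', 1), ('a', 2), ('b', 3)], ['grp', 'v'])
>>> w = Window.partitionBy('grp').orderBy('v')
>>> df.select('grp', 'v', lag('v', 1).over(w).alias('prev')).orderBy('grp', 'v').collect()
[Row(grp='a', v=1, prev=None), Row(grp='a', v=2, prev=1), Row(grp='b', v=3, prev=None)]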
pyspark.sql.functions.last
pyspark.sql.functions.last(col, ignorenulls=False)
Aggregate function: returns the last value in a group.
The function by default returns the last values it sees. It will return the last non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
New in version 1.3.0.
Notes
The function is non-deterministic because its result depends on the order of the rows which may be non-deterministic after a shuffle.
pyspark.sql.functions.last_day
pyspark.sql.functions.last_day(date)
Returns the last day of the month which the given date belongs to.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('1997-02-10',)], ['d'])
>>> df.select(last_day(df.d).alias('date')).collect()
[Row(date=datetime.date(1997, 2, 28))]
pyspark.sql.functions.lead
pyspark.sql.functions.lead(col, offset=1, default=None)
Window function: returns the value that is offset rows after the current row, and defaultValue if there is less than offset rows after the current row. For example, an offset of one will return the next row at any given point in the window partition.
This is equivalent to the LEAD function in SQL.
New in version 1.4.0.
Parameters
col [Column or str] name of column or expression
offset [int, optional] number of rows to extend
default [optional] default value
pyspark.sql.functions.least
pyspark.sql.functions.least(*cols)
Returns the least value of the list of column names, skipping null values. This function takes at least 2 parameters. It will return null iff all parameters are null.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
>>> df.select(least(df.a, df.b, df.c).alias("least")).collect()
[Row(least=1)]
pyspark.sql.functions.length
pyspark.sql.functions.length(col)
Computes the character length of string data or number of bytes of binary data. The length of character data includes the trailing spaces. The length of binary data includes binary zeros.
New in version 1.5.0.
Examples
>>> spark.createDataFrame([('ABC ',)], ['a']).select(length('a').alias('length')).collect()
[Row(length=4)]
pyspark.sql.functions.levenshtein
pyspark.sql.functions.levenshtein(left, right)
Computes the Levenshtein distance of the two given strings.
New in version 1.5.0.
Examples
>>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
>>> df0.select(levenshtein('l', 'r').alias('d')).collect()
[Row(d=3)]
pyspark.sql.functions.lit
pyspark.sql.functions.lit(col)
Creates a Column of literal value.
New in version 1.3.0.
Examples
>>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1)
[Row(height=5, spark_user=True)]
pyspark.sql.functions.locate
pyspark.sql.functions.locate(substr, str, pos=1)
Locate the position of the first occurrence of substr in a string column, after position pos.
New in version 1.5.0.
Parameters
substr [str] a string
str [Column or str] a Column of pyspark.sql.types.StringType
pos [int, optional] start position (zero based)
Notes
The position is not zero-based but 1-based. Returns 0 if substr could not be found in str.
Examples
>>> df = spark.createDataFrame([('abcd',)], ['s',])
>>> df.select(locate('b', df.s, 1).alias('s')).collect()
[Row(s=2)]
pyspark.sql.functions.log
pyspark.sql.functions.log(arg1, arg2=None)
Returns the first argument-based logarithm of the second argument.
If there is only one argument, then this takes the natural logarithm of the argument.
New in version 1.5.0.
Examples
>>> df.select(log(10.0, df.age).alias('ten')).rdd.map(lambda l: str(l.ten)[:7]).collect()
['0.30102', '0.69897']
>>> df.select(log(df.age).alias('e')).rdd.map(lambda l: str(l.e)[:7]).collect()
['0.69314', '1.60943']
pyspark.sql.functions.log10
pyspark.sql.functions.log10(col)
Computes the logarithm of the given value in Base 10.
New in version 1.4.
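A minimal sketch, not part of the upstream reference, assuming the usual spark session; log10(100) is 2:

>>> spark.createDataFrame([(100.0,)], ['x']).select(log10('x').alias('r')).collect()
[Row(r=2.0)]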
pyspark.sql.functions.log1p
pyspark.sql.functions.log1p(col)
Computes the natural logarithm of the given value plus one.
New in version 1.4.
pyspark.sql.functions.log2
pyspark.sql.functions.log2(col)
Returns the base-2 logarithm of the argument.
New in version 1.5.0.
Examples
>>> spark.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect()
[Row(log2=2.0)]
pyspark.sql.functions.lower
pyspark.sql.functions.lower(col)
Converts a string expression to lower case.
New in version 1.5.
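A minimal sketch, not part of the upstream reference, assuming the usual spark session:

>>> spark.createDataFrame([('Spark SQL',)], ['a']).select(lower('a').alias('v')).collect()
[Row(v='spark sql')]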
pyspark.sql.functions.lpad
pyspark.sql.functions.lpad(col, len, pad)
Left-pad the string column to width len with pad.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('abcd',)], ['s',])
>>> df.select(lpad(df.s, 6, '#').alias('s')).collect()
[Row(s='##abcd')]
pyspark.sql.functions.ltrim
pyspark.sql.functions.ltrim(col)
Trim the spaces from left end for the specified string value.
New in version 1.5.
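A minimal sketch, not part of the upstream reference, assuming the usual spark session:

>>> spark.createDataFrame([('   Spark',)], ['a']).select(ltrim('a').alias('v')).collect()
[Row(v='Spark')]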
pyspark.sql.functions.map_concat
pyspark.sql.functions.map_concat(*cols)
Returns the union of all the given maps.
New in version 2.4.0.
Parameters
cols [Column or str] column names or Columns
Examples
>>> from pyspark.sql.functions import map_concat
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c') as map2")
>>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False)
+------------------------+
|map3                    |
+------------------------+
|{1 -> a, 2 -> b, 3 -> c}|
+------------------------+
pyspark.sql.functions.map_entries
pyspark.sql.functions.map_entries(col)
Collection function: Returns an unordered array of all entries in the given map.
New in version 3.0.0.
Parameters
col [Column or str] name of column or expression
Examples
>>> from pyspark.sql.functions import map_entries
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
>>> df.select(map_entries("data").alias("entries")).show()
+----------------+
|         entries|
+----------------+
|[{1, a}, {2, b}]|
+----------------+
pyspark.sql.functions.map_filter
pyspark.sql.functions.map_filter(col, f)
Returns a map whose key-value pairs satisfy a predicate.
New in version 3.1.0.
Parameters
col [Column or str] name of column or expression
f [function] a binary function (k: Column, v: Column) -> Column... Can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).
Returns
pyspark.sql.Column
Examples
>>> df = spark.createDataFrame([(1, {"foo": 42.0, "bar": 1.0, "baz": 32.0})], ("id", "data"))
>>> df.select(map_filter(
...     "data", lambda _, v: v > 30.0).alias("data_filtered")
... ).show(truncate=False)
+--------------------------+
|data_filtered             |
+--------------------------+
|{baz -> 32.0, foo -> 42.0}|
+--------------------------+
pyspark.sql.functions.map_from_arrays
pyspark.sql.functions.map_from_arrays(col1, col2)
Creates a new map from two arrays.
New in version 2.4.0.
Parameters
col1 [Column or str] name of column containing a set of keys. All elements should not be null
col2 [Column or str] name of column containing a set of values
Examples
>>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v'])
>>> df.select(map_from_arrays(df.k, df.v).alias("map")).show()
+----------------+
|             map|
+----------------+
|{2 -> a, 5 -> b}|
+----------------+
pyspark.sql.functions.map_from_entries
pyspark.sql.functions.map_from_entries(col)
Collection function: Returns a map created from the given array of entries.
New in version 2.4.0.
Parameters
col [Column or str] name of column or expression
Examples
>>> from pyspark.sql.functions import map_from_entries
>>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data")
>>> df.select(map_from_entries("data").alias("map")).show()
+----------------+
|             map|
+----------------+
|{1 -> a, 2 -> b}|
+----------------+
pyspark.sql.functions.map_keys
pyspark.sql.functions.map_keys(col)
Collection function: Returns an unordered array containing the keys of the map.
New in version 2.3.0.
Parameters
col [Column or str] name of column or expression
Examples
>>> from pyspark.sql.functions import map_keys
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
>>> df.select(map_keys("data").alias("keys")).show()
+------+
|  keys|
+------+
|[1, 2]|
+------+
pyspark.sql.functions.map_values
pyspark.sql.functions.map_values(col)
Collection function: Returns an unordered array containing the values of the map.
New in version 2.3.0.
Parameters
col [Column or str] name of column or expression
Examples
>>> from pyspark.sql.functions import map_values
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
>>> df.select(map_values("data").alias("values")).show()
+------+
|values|
+------+
|[a, b]|
+------+
pyspark.sql.functions.map_zip_with
pyspark.sql.functions.map_zip_with(col1, col2, f)
Merge two given maps, key-wise into a single map using a function.
New in version 3.1.0.
Parameters
col1 [Column or str] name of the first column or expression
col2 [Column or str] name of the second column or expression
f [function] a ternary function (k: Column, v1: Column, v2: Column) -> Column... Can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).
Returns
pyspark.sql.Column
Examples
>>> df = spark.createDataFrame([
...     (1, {"IT": 24.0, "SALES": 12.00}, {"IT": 2.0, "SALES": 1.4})],
...     ("id", "base", "ratio")
... )
>>> df.select(map_zip_with(
...     "base", "ratio", lambda k, v1, v2: round(v1 * v2, 2)).alias("updated_data")
... ).show(truncate=False)
+---------------------------+
|updated_data               |
+---------------------------+
|{SALES -> 16.8, IT -> 48.0}|
+---------------------------+
pyspark.sql.functions.max
pyspark.sql.functions.max(col)
Aggregate function: returns the maximum value of the expression in a group.
New in version 1.3.
pyspark.sql.functions.md5
pyspark.sql.functions.md5(col)
Calculates the MD5 digest and returns the value as a 32 character hex string.
New in version 1.5.0.
Examples
>>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect()
[Row(hash='902fbdd2b1df0c4f70b4a5d23525e932')]
pyspark.sql.functions.mean
pyspark.sql.functions.mean(col)
Aggregate function: returns the average of the values in a group.
New in version 1.3.
pyspark.sql.functions.min
pyspark.sql.functions.min(col)
Aggregate function: returns the minimum value of the expression in a group.
New in version 1.3.
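An illustrative sketch covering both min and max, not from the upstream reference, assuming the usual spark session; note these names shadow the Python built-ins when imported with a star import:

>>> df = spark.createDataFrame([(2,), (5,), (5,)], ['age'])
>>> df.agg(min('age').alias('lo'), max('age').alias('hi')).collect()
[Row(lo=2, hi=5)]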
pyspark.sql.functions.minute
pyspark.sql.functions.minute(col)
Extract the minutes of a given date as integer.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts'])
>>> df.select(minute('ts').alias('minute')).collect()
[Row(minute=8)]
pyspark.sql.functions.monotonically_increasing_id
pyspark.sql.functions.monotonically_increasing_id()
A column that generates monotonically increasing 64-bit integers.
The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. The current implementation puts the partition ID in the upper 31 bits, and the record number within each partition in the lower 33 bits. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records.
New in version 1.6.0.
Notes
The function is non-deterministic because its result depends on partition IDs.
As an example, consider a DataFrame with two partitions, each with 3 records. This expression would return the following IDs: 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594.
Examples

>>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1'])
>>> df0.select(monotonically_increasing_id().alias('id')).collect()
[Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)]
pyspark.sql.functions.month
pyspark.sql.functions.month(col)
Extract the month of a given date as integer.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(month('dt').alias('month')).collect()
[Row(month=4)]
pyspark.sql.functions.months
pyspark.sql.functions.months(col)
Partition transform function: A transform for timestamps and dates to partition data into months.
New in version 3.1.0.
Notes
This function can be used only in combination with partitionedBy() method of the DataFrameWriterV2.
Examples
>>> df.writeTo("catalog.db.table").partitionedBy(
...     months("ts")
... ).createOrReplace()
pyspark.sql.functions.months_between
pyspark.sql.functions.months_between(date1, date2, roundOff=True)
Returns number of months between dates date1 and date2. If date1 is later than date2, then the result is positive. If date1 and date2 are on the same day of month, or both are the last day of month, returns an integer (time of day will be ignored). The result is rounded off to 8 digits unless roundOff is set to False.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2'])
>>> df.select(months_between(df.date1, df.date2).alias('months')).collect()
[Row(months=3.94959677)]
>>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect()
[Row(months=3.9495967741935485)]
pyspark.sql.functions.nanvl
pyspark.sql.functions.nanvl(col1, col2)
Returns col1 if it is not NaN, or col2 if col1 is NaN.
Both inputs should be floating point columns (DoubleType or FloatType).
New in version 1.6.0.
Examples
>>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b"))
>>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect()
[Row(r1=1.0, r2=1.0), Row(r1=2.0, r2=2.0)]
pyspark.sql.functions.next_day
pyspark.sql.functions.next_day(date, dayOfWeek)
Returns the first date which is later than the value of the date column.
Day of the week parameter is case insensitive, and accepts: “Mon”, “Tue”, “Wed”, “Thu”, “Fri”, “Sat”, “Sun”.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-07-27',)], ['d'])
>>> df.select(next_day(df.d, 'Sun').alias('date')).collect()
[Row(date=datetime.date(2015, 8, 2))]
pyspark.sql.functions.nth_value
pyspark.sql.functions.nth_value(col, offset, ignoreNulls=False)
Window function: returns the value that is the offsetth row of the window frame (counting from 1), and null if the size of window frame is less than offset rows.
It will return the offsetth non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
This is equivalent to the nth_value function in SQL.
New in version 3.1.0.
Parameters
col [Column or str] name of column or expression
offset [int, optional] number of row to use as the value
ignoreNulls [bool, optional] indicates the Nth value should skip null in the determination of which row to use
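A hedged sketch, not from the upstream reference, assuming spark and Window imported from pyspark.sql; with the default frame ending at the current row, rows before the second value see a frame that is too small and get null:

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([(1,), (2,), (3,)], ['v'])
>>> df.select('v', nth_value('v', 2).over(Window.orderBy('v')).alias('nv')).collect()
[Row(v=1, nv=None), Row(v=2, nv=2), Row(v=3, nv=2)]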
pyspark.sql.functions.ntile
pyspark.sql.functions.ntile(n)
Window function: returns the ntile group id (from 1 to n inclusive) in an ordered window partition. For example, if n is 4, the first quarter of the rows will get value 1, the second quarter will get 2, the third quarter will get 3, and the last quarter will get 4.
This is equivalent to the NTILE function in SQL.
New in version 1.4.0.
Parameters
n [int] an integer
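An illustrative sketch, not from the upstream reference, assuming spark and Window imported from pyspark.sql; four rows split into two tiles:

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([(1,), (2,), (3,), (4,)], ['v'])
>>> df.select('v', ntile(2).over(Window.orderBy('v')).alias('half')).collect()
[Row(v=1, half=1), Row(v=2, half=1), Row(v=3, half=2), Row(v=4, half=2)]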
pyspark.sql.functions.overlay
pyspark.sql.functions.overlay(src, replace, pos, len=-1)
Overlay the specified portion of src with replace, starting from byte position pos of src and proceeding for len bytes.
New in version 3.0.0.
Examples
>>> df = spark.createDataFrame([("SPARK_SQL", "CORE")], ("x", "y"))
>>> df.select(overlay("x", "y", 7).alias("overlayed")).show()
+----------+
| overlayed|
+----------+
|SPARK_CORE|
+----------+
pyspark.sql.functions.pandas_udf
pyspark.sql.functions.pandas_udf(f=None, returnType=None, functionType=None)
Creates a pandas user defined function (a.k.a. vectorized user defined function).
Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and Pandas to work with the data, which allows vectorized operations. A Pandas UDF is defined using the pandas_udf as a decorator or to wrap the function, and no additional configuration is required. A Pandas UDF behaves as a regular PySpark function API in general.
New in version 2.3.0.
Parameters
f [function, optional] user-defined function. A python function if used as a standalone function
returnType [pyspark.sql.types.DataType or str, optional] the return type of the user-defined function. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.
functionType [int, optional] an enum value in pyspark.sql.functions.PandasUDFType. Default: SCALAR. This parameter exists for compatibility. Using Python type hints is encouraged.
See also:
pyspark.sql.GroupedData.agg
pyspark.sql.DataFrame.mapInPandas
pyspark.sql.GroupedData.applyInPandas
pyspark.sql.PandasCogroupedOps.applyInPandas
pyspark.sql.UDFRegistration.register
Notes
The user-defined functions do not support conditional expressions or short circuiting in boolean expressions, and they end up being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions.
The user-defined functions do not take keyword arguments on the calling side.
The data type of returned pandas.Series from the user-defined functions should be matched with the defined returnType (see types.to_arrow_type() and types.from_arrow_type()). When there is a mismatch between them, Spark might do conversion on returned data. The conversion is not guaranteed to be correct and results should be checked for accuracy by users.
Currently, pyspark.sql.types.ArrayType of pyspark.sql.types.TimestampType and nested pyspark.sql.types.StructType are not supported as output types.
Examples
In order to use this API, customarily the below are imported:
>>> import pandas as pd
>>> from pyspark.sql.functions import pandas_udf
From Spark 3.0 with Python 3.6+, Python type hints detect the function types as below:
>>> @pandas_udf(IntegerType())
... def slen(s: pd.Series) -> pd.Series:
...     return s.str.len()
Prior to Spark 3.0, the pandas UDF used functionType to decide the execution type as below:
>>> from pyspark.sql.functions import PandasUDFType
>>> from pyspark.sql.types import IntegerType
>>> @pandas_udf(IntegerType(), PandasUDFType.SCALAR)
... def slen(s):
...     return s.str.len()
It is preferred to specify type hints for the pandas UDF instead of specifying pandas UDF type via functionTypewhich will be deprecated in the future releases.
Note that the type hint should use pandas.Series in all cases but there is one variant that pandas.DataFrame should be used for its input or output type hint instead when the input or output column is of pyspark.sql.types.StructType. The following example shows a Pandas UDF which takes long column, string column and struct column, and outputs a struct column. It requires the function to specify the type hints of pandas.Series and pandas.DataFrame as below:
>>> @pandas_udf("col1 string, col2 long")
... def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame:
...     s3['col2'] = s1 + s2.str.len()
...     return s3
...
>>> # Create a Spark DataFrame that has three columns including a struct column.
... df = spark.createDataFrame(
...     [[1, "a string", ("a nested string",)]],
...     "long_col long, string_col string, struct_col struct<col1:string>")
>>> df.printSchema()
root
 |-- long_col: long (nullable = true)
 |-- string_col: string (nullable = true)
 |-- struct_col: struct (nullable = true)
 |    |-- col1: string (nullable = true)
>>> df.select(func("long_col", "string_col", "struct_col")).printSchema()
root
 |-- func(long_col, string_col, struct_col): struct (nullable = true)
 |    |-- col1: string (nullable = true)
 |    |-- col2: long (nullable = true)
The following sections describe the combinations of the supported type hints. For simplicity, the pandas.DataFrame variant is omitted.
• Series to Series pandas.Series, ... -> pandas.Series
The function takes one or more pandas.Series and outputs one pandas.Series. The output of the function should always be of the same length as the input.
>>> @pandas_udf("string")
... def to_upper(s: pd.Series) -> pd.Series:
...     return s.str.upper()
...
>>> df = spark.createDataFrame([("John Doe",)], ("name",))
>>> df.select(to_upper("name")).show()
+--------------+
|to_upper(name)|
+--------------+
|      JOHN DOE|
+--------------+
>>> @pandas_udf("first string, last string")
... def split_expand(s: pd.Series) -> pd.DataFrame:
...     return s.str.split(expand=True)
...
>>> df = spark.createDataFrame([("John Doe",)], ("name",))
>>> df.select(split_expand("name")).show()
+------------------+
|split_expand(name)|
+------------------+
|       [John, Doe]|
+------------------+
Note: The length of the input is not that of the whole input column, but is the length of an internal batch used for each call to the function.
• Iterator of Series to Iterator of Series Iterator[pandas.Series] -> Iterator[pandas.Series]
The function takes an iterator of pandas.Series and outputs an iterator of pandas.Series. In this case, the created pandas UDF instance requires one input column when this is called as a PySpark column. The length of the entire output from the function should be the same length of the entire input; therefore, it can prefetch the data from the input iterator as long as the lengths are the same.
It is also useful when the UDF execution requires initializing some state, although internally it works identically to the Series to Series case. The pseudocode below illustrates the example.
@pandas_udf("long")
def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Do some expensive initialization with a state
    state = very_expensive_initialization()
    for x in iterator:
        # Use that state for whole iterator.
        yield calculate_with_state(x, state)

df.select(calculate("value")).show()
>>> from typing import Iterator
>>> @pandas_udf("long")
... def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
...     for s in iterator:
...         yield s + 1
...
>>> df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"]))
>>> df.select(plus_one(df.v)).show()
+-----------+
|plus_one(v)|
+-----------+
|          2|
|          3|
|          4|
+-----------+
Note: The length of each series is the length of a batch internally used.
• Iterator of Multiple Series to Iterator of Series Iterator[Tuple[pandas.Series, ...]] -> Iterator[pandas.Series]
The function takes an iterator of a tuple of multiple pandas.Series and outputs an iterator of pandas.Series. In this case, the created pandas UDF instance requires input columns as many as the series when this is called as a PySpark column. Otherwise, it has the same characteristics and restrictions as the Iterator of Series to Iterator of Series case.
>>> from typing import Iterator, Tuple
>>> from pyspark.sql.functions import struct, col
>>> @pandas_udf("long")
... def multiply(iterator: Iterator[Tuple[pd.Series, pd.DataFrame]]) -> Iterator[pd.Series]:
...     for s1, df in iterator:
...         yield s1 * df.v
...
>>> df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"]))
>>> df.withColumn('output', multiply(col("v"), struct(col("v")))).show()
+---+------+
|  v|output|
+---+------+
|  1|     1|
|  2|     4|
|  3|     9|
+---+------+
Note: The length of each series is the length of an internally used batch.
• Series to Scalar: pandas.Series, ... -> Any
The function takes pandas.Series and returns a scalar value. The returnType should be a primitive data type, and the returned scalar can be either a Python primitive type, e.g., int or float, or a NumPy data type, e.g., numpy.int64 or numpy.float64. Any should ideally be a specific scalar type accordingly.
>>> @pandas_udf("double")... def mean_udf(v: pd.Series) -> float:... return v.mean()...>>> df = spark.createDataFrame(... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))>>> df.groupby("id").agg(mean_udf(df['v'])).show()
(continues on next page)
3.1. Spark SQL 211
pyspark Documentation, Release master
(continued from previous page)
+---+-----------+| id|mean_udf(v)|+---+-----------+| 1| 1.5|| 2| 6.0|+---+-----------+
This UDF can also be used as a window function, as below:
>>> from pyspark.sql import Window
>>> @pandas_udf("double")
... def mean_udf(v: pd.Series) -> float:
...     return v.mean()
...
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
>>> w = Window.partitionBy('id').orderBy('v').rowsBetween(-1, 0)
>>> df.withColumn('mean_v', mean_udf("v").over(w)).show()
+---+----+------+
| id|   v|mean_v|
+---+----+------+
|  1| 1.0|   1.0|
|  1| 2.0|   1.5|
|  2| 3.0|   3.0|
|  2| 5.0|   4.0|
|  2|10.0|   7.5|
+---+----+------+
Note: For performance reasons, the input series to window functions are not copied. Therefore, mutating the input series is not allowed and will cause incorrect results. For the same reason, users should also not rely on the index of the input series.
pyspark.sql.functions.percent_rank
pyspark.sql.functions.percent_rank()
Window function: returns the relative rank (i.e. percentile) of rows within a window partition.
New in version 1.6.
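This entry ships without an example. A minimal hedged sketch follows (the toy DataFrame and output are illustrative, not from the original reference); percent_rank computes (rank - 1) / (rows in partition - 1):

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([(1,), (1,), (2,), (3,)], ['v'])  # hypothetical toy data
>>> w = Window.orderBy('v')
>>> df.select('v', percent_rank().over(w).alias('pr')).collect()
[Row(v=1, pr=0.0), Row(v=1, pr=0.0), Row(v=2, pr=0.6666666666666666), Row(v=3, pr=1.0)]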
pyspark.sql.functions.percentile_approx
pyspark.sql.functions.percentile_approx(col, percentage, accuracy=10000)
Returns the approximate percentile of the numeric column col, which is the smallest value in the ordered col values (sorted from least to greatest) such that no more than percentage of col values is less than or equal to that value. The value of percentage must be between 0.0 and 1.0.
The accuracy parameter (default: 10000) is a positive numeric literal which controls approximation accuracy at the cost of memory. A higher value of accuracy yields better accuracy; 1.0/accuracy is the relative error of the approximation.
When percentage is an array, each value of the percentage array must be between 0.0 and 1.0. In this case,returns the approximate percentile array of column col at the given percentage array.
New in version 3.1.0.
Examples
>>> key = (col("id") % 3).alias("key")>>> value = (randn(42) + key * 10).alias("value")>>> df = spark.range(0, 1000, 1, 1).select(key, value)>>> df.select(... percentile_approx("value", [0.25, 0.5, 0.75], 1000000).alias("quantiles")... ).printSchema()root|-- quantiles: array (nullable = true)| |-- element: double (containsNull = false)
>>> df.groupBy("key").agg(... percentile_approx("value", 0.5, lit(1000000)).alias("median")... ).printSchema()root|-- key: long (nullable = true)|-- median: double (nullable = true)
pyspark.sql.functions.posexplode
pyspark.sql.functions.posexplode(col)
Returns a new row for each element with position in the given array or map. Uses the default column name pos for position, col for elements in the array, and key and value for elements in the map, unless specified otherwise.
New in version 2.1.0.
Examples
>>> from pyspark.sql import Row
>>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
>>> eDF.select(posexplode(eDF.intlist)).collect()
[Row(pos=0, col=1), Row(pos=1, col=2), Row(pos=2, col=3)]
>>> eDF.select(posexplode(eDF.mapfield)).show()
+---+---+-----+
|pos|key|value|
+---+---+-----+
|  0|  a|    b|
+---+---+-----+
pyspark.sql.functions.posexplode_outer
pyspark.sql.functions.posexplode_outer(col)
Returns a new row for each element with position in the given array or map. Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced. Uses the default column name pos for position, col for elements in the array, and key and value for elements in the map, unless specified otherwise.
New in version 2.3.0.
Examples
>>> df = spark.createDataFrame(
...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
...     ("id", "an_array", "a_map")
... )
>>> df.select("id", "an_array", posexplode_outer("a_map")).show()
+---+----------+----+----+-----+
| id|  an_array| pos| key|value|
+---+----------+----+----+-----+
|  1|[foo, bar]|   0|   x|  1.0|
|  2|        []|null|null| null|
|  3|      null|null|null| null|
+---+----------+----+----+-----+
>>> df.select("id", "a_map", posexplode_outer("an_array")).show()
+---+----------+----+----+
| id|     a_map| pos| col|
+---+----------+----+----+
|  1|{x -> 1.0}|   0| foo|
|  1|{x -> 1.0}|   1| bar|
|  2|        {}|null|null|
|  3|      null|null|null|
+---+----------+----+----+
pyspark.sql.functions.pow
pyspark.sql.functions.pow(col1, col2)
Returns the value of the first argument raised to the power of the second argument.
New in version 1.4.
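No example accompanies this entry. A short hedged sketch (toy data is illustrative; the result is returned as a double):

>>> df = spark.createDataFrame([(2, 3)], ['base', 'exp'])  # hypothetical toy data
>>> df.select(pow(df.base, df.exp).alias('p')).collect()
[Row(p=8.0)]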
pyspark.sql.functions.quarter
pyspark.sql.functions.quarter(col)
Extract the quarter of a given date as integer.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(quarter('dt').alias('quarter')).collect()
[Row(quarter=2)]
pyspark.sql.functions.radians
pyspark.sql.functions.radians(col)
Converts an angle measured in degrees to an approximately equivalent angle measured in radians.
New in version 2.1.0.
Parameters
col [Column or str] angle in degrees
Returns
Column angle in radians, as if computed by java.lang.Math.toRadians()
pyspark.sql.functions.raise_error
pyspark.sql.functions.raise_error(errMsg)
Throws an exception with the provided error message.
New in version 3.1.
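A hedged usage sketch (not from the original reference): evaluating the column raises an error carrying the message; the exact Python exception wrapper surfaced from the JVM may vary by Spark version.

>>> df = spark.range(1)
>>> df.select(raise_error("My error message")).collect()  # raises once evaluated
Traceback (most recent call last):
    ...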
pyspark.sql.functions.rand
pyspark.sql.functions.rand(seed=None)
Generates a random column with independent and identically distributed (i.i.d.) samples uniformly distributed in [0.0, 1.0).
New in version 1.4.0.
Notes
The function is non-deterministic in the general case.
Examples
>>> df.withColumn('rand', rand(seed=42) * 3).collect()
[Row(age=2, name='Alice', rand=2.4052597283576684),
 Row(age=5, name='Bob', rand=2.3913904055683974)]
pyspark.sql.functions.randn
pyspark.sql.functions.randn(seed=None)
Generates a column with independent and identically distributed (i.i.d.) samples from the standard normal distribution.
New in version 1.4.0.
Notes
The function is non-deterministic in the general case.
Examples
>>> df.withColumn('randn', randn(seed=42)).collect()
[Row(age=2, name='Alice', randn=1.1027054481455365),
 Row(age=5, name='Bob', randn=0.7400395449950132)]
pyspark.sql.functions.rank
pyspark.sql.functions.rank()
Window function: returns the rank of rows within a window partition.
The difference between rank and dense_rank is that dense_rank leaves no gaps in the ranking sequence when there are ties. That is, if you were ranking a competition using dense_rank and had three people tie for second place, you would say that all three were in second place and that the next person came in third. Rank would give you sequential numbers, so the person that came in third place (after the ties) would register as coming in fifth.
This is equivalent to the RANK function in SQL.
New in version 1.6.
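A minimal hedged sketch of the gap-leaving behavior (toy data and output are illustrative, not from the original reference):

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([(1,), (1,), (2,), (3,)], ['v'])  # hypothetical toy data
>>> w = Window.orderBy('v')
>>> df.select('v', rank().over(w).alias('rank')).collect()
[Row(v=1, rank=1), Row(v=1, rank=1), Row(v=2, rank=3), Row(v=3, rank=4)]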
pyspark.sql.functions.regexp_extract
pyspark.sql.functions.regexp_extract(str, pattern, idx)
Extract a specific group matched by a Java regex, from the specified string column. If the regex did not match, or the specified group did not match, an empty string is returned.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('100-200',)], ['str'])
>>> df.select(regexp_extract('str', r'(\d+)-(\d+)', 1).alias('d')).collect()
[Row(d='100')]
>>> df = spark.createDataFrame([('foo',)], ['str'])
>>> df.select(regexp_extract('str', r'(\d+)', 1).alias('d')).collect()
[Row(d='')]
>>> df = spark.createDataFrame([('aaaac',)], ['str'])
>>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect()
[Row(d='')]
pyspark.sql.functions.regexp_replace
pyspark.sql.functions.regexp_replace(str, pattern, replacement)
Replace all substrings of the specified string value that match pattern with replacement.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('100-200',)], ['str'])
>>> df.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect()
[Row(d='-----')]
pyspark.sql.functions.repeat
pyspark.sql.functions.repeat(col, n)
Repeats a string column n times, and returns it as a new string column.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('ab',)], ['s',])
>>> df.select(repeat(df.s, 3).alias('s')).collect()
[Row(s='ababab')]
pyspark.sql.functions.reverse
pyspark.sql.functions.reverse(col)
Collection function: returns a reversed string or an array with reverse order of elements.
New in version 1.5.0.
Parameters
col [Column or str] name of column or expression
Examples
>>> df = spark.createDataFrame([('Spark SQL',)], ['data'])
>>> df.select(reverse(df.data).alias('s')).collect()
[Row(s='LQS krapS')]
>>> df = spark.createDataFrame([([2, 1, 3],), ([1],), ([],)], ['data'])
>>> df.select(reverse(df.data).alias('r')).collect()
[Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])]
pyspark.sql.functions.rint
pyspark.sql.functions.rint(col)
Returns the double value that is closest in value to the argument and is equal to a mathematical integer.
New in version 1.4.
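A short hedged sketch (toy data is illustrative): unlike round below, rint follows java.lang.Math.rint and resolves ties half-to-even.

>>> df = spark.createDataFrame([(2.5,), (3.5,)], ['a'])  # hypothetical toy data
>>> df.select(rint('a').alias('r')).collect()
[Row(r=2.0), Row(r=4.0)]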
pyspark.sql.functions.round
pyspark.sql.functions.round(col, scale=0)
Round the given value to scale decimal places using HALF_UP rounding mode if scale >= 0, or at the integral part when scale < 0.
New in version 1.5.0.
Examples
>>> spark.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect()
[Row(r=3.0)]
pyspark.sql.functions.row_number
pyspark.sql.functions.row_number()
Window function: returns a sequential number starting at 1 within a window partition.
New in version 1.6.
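A minimal hedged sketch (toy data is illustrative; among tied ordering values the assignment of 1 vs. 2 is not deterministic):

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([(1,), (1,), (2,)], ['v'])  # hypothetical toy data
>>> w = Window.orderBy('v')
>>> df.select('v', row_number().over(w).alias('rn')).collect()
[Row(v=1, rn=1), Row(v=1, rn=2), Row(v=2, rn=3)]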
pyspark.sql.functions.rpad
pyspark.sql.functions.rpad(col, len, pad)
Right-pad the string column to width len with pad.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('abcd',)], ['s',])
>>> df.select(rpad(df.s, 6, '#').alias('s')).collect()
[Row(s='abcd##')]
pyspark.sql.functions.rtrim
pyspark.sql.functions.rtrim(col)
Trim the spaces from the right end of the specified string value.
New in version 1.5.
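A short hedged sketch (toy data is illustrative; leading spaces are kept, trailing spaces are removed):

>>> df = spark.createDataFrame([('  Spark  ',)], ['s'])  # hypothetical toy data
>>> df.select(rtrim('s').alias('r')).collect()
[Row(r='  Spark')]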
pyspark.sql.functions.schema_of_csv
pyspark.sql.functions.schema_of_csv(csv, options=None)
Parses a CSV string and infers its schema in DDL format.
New in version 3.0.0.
Parameters
csv [Column or str] a CSV string or a foldable string column containing a CSV string.
options [dict, optional] options to control parsing. Accepts the same options as the CSV datasource.
Examples
>>> df = spark.range(1)
>>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect()
[Row(csv='STRUCT<`_c0`: INT, `_c1`: STRING>')]
>>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect()
[Row(csv='STRUCT<`_c0`: INT, `_c1`: STRING>')]
pyspark.sql.functions.schema_of_json
pyspark.sql.functions.schema_of_json(json, options=None)
Parses a JSON string and infers its schema in DDL format.
New in version 2.4.0.
Parameters
json [Column or str] a JSON string or a foldable string column containing a JSON string.
options [dict, optional] options to control parsing. Accepts the same options as the JSON datasource.
Changed in version 3.0: It accepts an options parameter to control schema inference.
Examples
>>> df = spark.range(1)
>>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect()
[Row(json='STRUCT<`a`: BIGINT>')]
>>> schema = schema_of_json('{a: 1}', {'allowUnquotedFieldNames':'true'})
>>> df.select(schema.alias("json")).collect()
[Row(json='STRUCT<`a`: BIGINT>')]
pyspark.sql.functions.second
pyspark.sql.functions.second(col)
Extract the seconds of a given date as integer.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts'])
>>> df.select(second('ts').alias('second')).collect()
[Row(second=15)]
pyspark.sql.functions.sequence
pyspark.sql.functions.sequence(start, stop, step=None)
Generate a sequence of integers from start to stop, incrementing by step. If step is not set, the sequence increments by 1 if start is less than or equal to stop, and by -1 otherwise.
New in version 2.4.0.
Examples
>>> df1 = spark.createDataFrame([(-2, 2)], ('C1', 'C2'))
>>> df1.select(sequence('C1', 'C2').alias('r')).collect()
[Row(r=[-2, -1, 0, 1, 2])]
>>> df2 = spark.createDataFrame([(4, -4, -2)], ('C1', 'C2', 'C3'))
>>> df2.select(sequence('C1', 'C2', 'C3').alias('r')).collect()
[Row(r=[4, 2, 0, -2, -4])]
pyspark.sql.functions.sha1
pyspark.sql.functions.sha1(col)
Returns the hex string result of SHA-1.
New in version 1.5.0.
Examples
>>> spark.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
[Row(hash='3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
pyspark.sql.functions.sha2
pyspark.sql.functions.sha2(col, numBits)
Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, and SHA-512). The numBits indicates the desired bit length of the result, which must have a value of 224, 256, 384, 512, or 0 (which is equivalent to 256).
New in version 1.5.0.
Examples
>>> digests = df.select(sha2(df.name, 256).alias('s')).collect()
>>> digests[0]
Row(s='3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043')
>>> digests[1]
Row(s='cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961')
pyspark.sql.functions.shiftLeft
pyspark.sql.functions.shiftLeft(col, numBits)
Shift the given value numBits left.
New in version 1.5.0.
Examples
>>> spark.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect()
[Row(r=42)]
pyspark.sql.functions.shiftRight
pyspark.sql.functions.shiftRight(col, numBits)
(Signed) shift the given value numBits right.
New in version 1.5.0.
Examples
>>> spark.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect()
[Row(r=21)]
pyspark.sql.functions.shiftRightUnsigned
pyspark.sql.functions.shiftRightUnsigned(col, numBits)
Unsigned shift the given value numBits right.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([(-42,)], ['a'])
>>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect()
[Row(r=9223372036854775787)]
pyspark.sql.functions.shuffle
pyspark.sql.functions.shuffle(col)
Collection function: Generates a random permutation of the given array.
New in version 2.4.0.
Parameters
col [Column or str] name of column or expression
Notes
The function is non-deterministic.
Examples
>>> df = spark.createDataFrame([([1, 20, 3, 5],), ([1, 20, None, 3],)], ['data'])
>>> df.select(shuffle(df.data).alias('s')).collect()
[Row(s=[3, 1, 5, 20]), Row(s=[20, None, 3, 1])]
pyspark.sql.functions.signum
pyspark.sql.functions.signum(col)
Computes the signum of the given value.
New in version 1.4.
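A short hedged sketch (toy data is illustrative; the result is returned as a double):

>>> df = spark.createDataFrame([(-5,), (0,), (7,)], ['v'])  # hypothetical toy data
>>> df.select(signum('v').alias('sign')).collect()
[Row(sign=-1.0), Row(sign=0.0), Row(sign=1.0)]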
pyspark.sql.functions.sin
pyspark.sql.functions.sin(col)
New in version 1.4.0.
Parameters
col [Column or str]
Returns
Column sine of the angle, as if computed by java.lang.Math.sin()
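A minimal hedged sketch (toy data is illustrative; the input is an angle in radians):

>>> import math
>>> df = spark.createDataFrame([(math.pi / 2,)], ['rad'])  # hypothetical toy data
>>> df.select(sin('rad').alias('s')).collect()
[Row(s=1.0)]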
pyspark.sql.functions.sinh
pyspark.sql.functions.sinh(col)
New in version 1.4.0.
Parameters
col [Column or str] hyperbolic angle
Returns
Column hyperbolic sine of the given value, as if computed by java.lang.Math.sinh()
pyspark.sql.functions.size
pyspark.sql.functions.size(col)
Collection function: returns the length of the array or map stored in the column.
New in version 1.5.0.
Parameters
col [Column or str] name of column or expression
Examples
>>> df = spark.createDataFrame([([1, 2, 3],), ([1],), ([],)], ['data'])
>>> df.select(size(df.data)).collect()
[Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)]
pyspark.sql.functions.skewness
pyspark.sql.functions.skewness(col)
Aggregate function: returns the skewness of the values in a group.
New in version 1.6.
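A short hedged sketch (toy data is illustrative; for these three values the skewness works out to 1/sqrt(2)):

>>> df = spark.createDataFrame([(1.0,), (1.0,), (2.0,)], ['v'])  # hypothetical toy data
>>> df.select(skewness('v').alias('sk')).collect()
[Row(sk=0.7071067811865475)]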
pyspark.sql.functions.slice
pyspark.sql.functions.slice(x, start, length)
Collection function: returns an array containing all the elements in x from index start (array indices start at 1, or from the end if start is negative) with the specified length.
New in version 2.4.0.
Parameters
x [Column or str] the array to be sliced
start [Column or int] the starting index
length [Column or int] the length of the slice
Examples
>>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
>>> df.select(slice(df.x, 2, 2).alias("sliced")).collect()
[Row(sliced=[2, 3]), Row(sliced=[5])]
pyspark.sql.functions.sort_array
pyspark.sql.functions.sort_array(col, asc=True)
Collection function: sorts the input array in ascending or descending order according to the natural ordering of the array elements. Null elements will be placed at the beginning of the returned array in ascending order or at the end of the returned array in descending order.
New in version 1.5.0.
Parameters
col [Column or str] name of column or expression
asc [bool, optional]
Examples
>>> df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data'])
>>> df.select(sort_array(df.data).alias('r')).collect()
[Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])]
>>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
[Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])]
pyspark.sql.functions.soundex
pyspark.sql.functions.soundex(col)
Returns the SoundEx encoding for a string.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([("Peters",),("Uhrbach",)], ['name'])>>> df.select(soundex(df.name).alias("soundex")).collect()[Row(soundex='P362'), Row(soundex='U612')]
pyspark.sql.functions.spark_partition_id
pyspark.sql.functions.spark_partition_id()
A column for partition ID.
New in version 1.6.0.
Notes
This is non-deterministic because it depends on data partitioning and task scheduling.
Examples
>>> df.repartition(1).select(spark_partition_id().alias("pid")).collect()
[Row(pid=0), Row(pid=0)]
pyspark.sql.functions.split
pyspark.sql.functions.split(str, pattern, limit=-1)
Splits str around matches of the given pattern.
New in version 1.5.0.
Parameters
str [Column or str] a string expression to split
pattern [str] a string representing a regular expression. The regex string should be a Java regular expression.
limit [int, optional] an integer which controls the number of times pattern is applied.
• limit > 0: The resulting array’s length will not be more than limit, and theresulting array’s last entry will contain all input beyond the last matched pattern.
• limit <= 0: pattern will be applied as many times as possible, and the resultingarray can be of any size.
Changed in version 3.0: split now takes an optional limit field. If not provided, the default limit value is -1.
Examples
>>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',])
>>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect()
[Row(s=['one', 'twoBthreeC'])]
>>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect()
[Row(s=['one', 'two', 'three', ''])]
pyspark.sql.functions.sqrt
pyspark.sql.functions.sqrt(col)
Computes the square root of the specified float value.
New in version 1.3.
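A short hedged sketch (toy data is illustrative; the result is returned as a double):

>>> df = spark.createDataFrame([(4,)], ['a'])  # hypothetical toy data
>>> df.select(sqrt('a').alias('r')).collect()
[Row(r=2.0)]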
pyspark.sql.functions.stddev
pyspark.sql.functions.stddev(col)
Aggregate function: alias for stddev_samp.
New in version 1.6.
pyspark.sql.functions.stddev_pop
pyspark.sql.functions.stddev_pop(col)
Aggregate function: returns population standard deviation of the expression in a group.
New in version 1.6.
pyspark.sql.functions.stddev_samp
pyspark.sql.functions.stddev_samp(col)
Aggregate function: returns the unbiased sample standard deviation of the expression in a group.
New in version 1.6.
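None of the stddev variants carry an example here. A minimal hedged sketch contrasting the sample and population estimators (toy data is illustrative; stddev itself would match stddev_samp):

>>> df = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ['v'])  # hypothetical toy data
>>> df.select(stddev_samp('v').alias('ss'), stddev_pop('v').alias('sp')).collect()
[Row(ss=1.0, sp=0.816496580927726)]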
pyspark.sql.functions.struct
pyspark.sql.functions.struct(*cols)
Creates a new struct column.
New in version 1.4.0.
Parameters
cols [list, set, str or Column] column names or Columns to contain in the output struct.
Examples
>>> df.select(struct('age', 'name').alias("struct")).collect()
[Row(struct=Row(age=2, name='Alice')), Row(struct=Row(age=5, name='Bob'))]
>>> df.select(struct([df.age, df.name]).alias("struct")).collect()
[Row(struct=Row(age=2, name='Alice')), Row(struct=Row(age=5, name='Bob'))]
pyspark.sql.functions.substring
pyspark.sql.functions.substring(str, pos, len)
Substring starts at pos and is of length len when str is String type, or returns the slice of the byte array that starts at pos (in bytes) and is of length len when str is Binary type.
New in version 1.5.0.
Notes
The position is not zero-based; it is a 1-based index.
Examples
>>> df = spark.createDataFrame([('abcd',)], ['s',])
>>> df.select(substring(df.s, 1, 2).alias('s')).collect()
[Row(s='ab')]
pyspark.sql.functions.substring_index
pyspark.sql.functions.substring_index(str, delim, count)
Returns the substring from string str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. substring_index performs a case-sensitive match when searching for delim.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('a.b.c.d',)], ['s'])
>>> df.select(substring_index(df.s, '.', 2).alias('s')).collect()
[Row(s='a.b')]
>>> df.select(substring_index(df.s, '.', -3).alias('s')).collect()
[Row(s='b.c.d')]
pyspark.sql.functions.sum
pyspark.sql.functions.sum(col)
Aggregate function: returns the sum of all values in the expression.
New in version 1.3.
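A short hedged sketch (toy data is illustrative):

>>> df = spark.createDataFrame([(1,), (2,), (2,)], ['v'])  # hypothetical toy data
>>> df.select(sum('v').alias('total')).collect()
[Row(total=5)]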
pyspark.sql.functions.sumDistinct
pyspark.sql.functions.sumDistinct(col)
Aggregate function: returns the sum of distinct values in the expression.
New in version 1.3.
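A short hedged sketch (toy data is illustrative; the duplicate 2 is counted once):

>>> df = spark.createDataFrame([(1,), (2,), (2,)], ['v'])  # hypothetical toy data
>>> df.select(sumDistinct('v').alias('total')).collect()
[Row(total=3)]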
pyspark.sql.functions.tan
pyspark.sql.functions.tan(col)
New in version 1.4.0.
Parameters
col [Column or str] angle in radians
Returns
Column tangent of the given value, as if computed by java.lang.Math.tan()
pyspark.sql.functions.tanh
pyspark.sql.functions.tanh(col)
New in version 1.4.0.
Parameters
col [Column or str] hyperbolic angle
Returns
Column hyperbolic tangent of the given value as if computed by java.lang.Math.tanh()
pyspark.sql.functions.timestamp_seconds
pyspark.sql.functions.timestamp_seconds(col)
Converts the number of seconds from the Unix epoch (1970-01-01 00:00:00 UTC) to a timestamp.
New in version 3.1.0.
Examples
>>> from pyspark.sql.functions import timestamp_seconds
>>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
>>> time_df = spark.createDataFrame([(1230219000,)], ['unix_time'])
>>> time_df.select(timestamp_seconds(time_df.unix_time).alias('ts')).show()
+-------------------+
|                 ts|
+-------------------+
|2008-12-25 07:30:00|
+-------------------+
>>> spark.conf.unset("spark.sql.session.timeZone")
pyspark.sql.functions.toDegrees
pyspark.sql.functions.toDegrees(col)
Deprecated since version 2.1.0: Use degrees() instead.
New in version 1.4.
pyspark.sql.functions.toRadians
pyspark.sql.functions.toRadians(col)
Deprecated since version 2.1.0: Use radians() instead.
New in version 1.4.
pyspark.sql.functions.to_csv
pyspark.sql.functions.to_csv(col, options=None)
Converts a column containing a StructType into a CSV string. Throws an exception in the case of an unsupported type.
New in version 3.0.0.
Parameters
col [Column or str] name of column containing a struct.
options [dict, optional] options to control converting. Accepts the same options as the CSV datasource.
Examples
>>> from pyspark.sql import Row
>>> data = [(1, Row(age=2, name='Alice'))]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(to_csv(df.value).alias("csv")).collect()
[Row(csv='2,Alice')]
pyspark.sql.functions.to_date
pyspark.sql.functions.to_date(col, format=None)
Converts a Column into pyspark.sql.types.DateType using the optionally specified format. Specify formats according to datetime pattern. By default, it follows casting rules to pyspark.sql.types.DateType if the format is omitted. Equivalent to col.cast("date").
New in version 2.2.0.
Examples
>>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
>>> df.select(to_date(df.t).alias('date')).collect()
[Row(date=datetime.date(1997, 2, 28))]

>>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
>>> df.select(to_date(df.t, 'yyyy-MM-dd HH:mm:ss').alias('date')).collect()
[Row(date=datetime.date(1997, 2, 28))]
pyspark.sql.functions.to_json
pyspark.sql.functions.to_json(col, options=None)
Converts a column containing a StructType, ArrayType or a MapType into a JSON string. Throws an exception in the case of an unsupported type.
New in version 2.1.0.
Parameters
col [Column or str] name of column containing a struct, an array or a map.
options [dict, optional] options to control converting. accepts the same options as the JSONdatasource. Additionally the function supports the pretty option which enables pretty JSONgeneration.
Examples
>>> from pyspark.sql import Row
>>> from pyspark.sql.types import *
>>> data = [(1, Row(age=2, name='Alice'))]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(to_json(df.value).alias("json")).collect()
[Row(json='{"age":2,"name":"Alice"}')]
>>> data = [(1, [Row(age=2, name='Alice'), Row(age=3, name='Bob')])]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(to_json(df.value).alias("json")).collect()
[Row(json='[{"age":2,"name":"Alice"},{"age":3,"name":"Bob"}]')]
>>> data = [(1, {"name": "Alice"})]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(to_json(df.value).alias("json")).collect()
[Row(json='{"name":"Alice"}')]
>>> data = [(1, [{"name": "Alice"}, {"name": "Bob"}])]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(to_json(df.value).alias("json")).collect()
[Row(json='[{"name":"Alice"},{"name":"Bob"}]')]
>>> data = [(1, ["Alice", "Bob"])]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(to_json(df.value).alias("json")).collect()
[Row(json='["Alice","Bob"]')]
pyspark.sql.functions.to_timestamp
pyspark.sql.functions.to_timestamp(col, format=None)
Converts a Column into pyspark.sql.types.TimestampType using the optionally specified format. Specify formats according to datetime pattern. By default, it follows casting rules to pyspark.sql.types.TimestampType if the format is omitted. Equivalent to col.cast("timestamp").
New in version 2.2.0.
Examples
>>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
>>> df.select(to_timestamp(df.t).alias('dt')).collect()
[Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]

>>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
>>> df.select(to_timestamp(df.t, 'yyyy-MM-dd HH:mm:ss').alias('dt')).collect()
[Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
pyspark.sql.functions.to_utc_timestamp
pyspark.sql.functions.to_utc_timestamp(timestamp, tz)
This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function takes a timestamp which is timezone-agnostic, interprets it as a timestamp in the given timezone, and renders that timestamp as a timestamp in UTC.
However, timestamp in Spark represents the number of microseconds from the Unix epoch, which is not timezone-agnostic. So in Spark this function just shifts the timestamp value from the given timezone to the UTC timezone.
This function may return a confusing result if the input is a string with a timezone, e.g. ‘2018-03-13T06:18:23+00:00’. The reason is that Spark first casts the string to a timestamp according to the timezone in the string, and finally displays the result by converting the timestamp to a string according to the session local timezone.
New in version 1.5.0.
Parameters
timestamp [Column or str] the column that contains timestamps
tz [Column or str] A string detailing the time zone ID that the input should be adjusted to. It should be in the format of either region-based zone IDs or zone offsets. Region IDs must have the form ‘area/city’, such as ‘America/Los_Angeles’. Zone offsets must be in the format ‘(+|-)HH:mm’, for example ‘-08:00’ or ‘+01:00’. Also ‘UTC’ and ‘Z’ are supported as aliases of ‘+00:00’. Other short names are not recommended because they can be ambiguous.
Changed in version 2.4.0: tz can take a Column containing timezone ID strings.
Examples
>>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz'])
>>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect()
[Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))]
>>> df.select(to_utc_timestamp(df.ts, df.tz).alias('utc_time')).collect()
[Row(utc_time=datetime.datetime(1997, 2, 28, 1, 30))]
pyspark.sql.functions.transform
pyspark.sql.functions.transform(col, f)
Returns an array of elements after applying a transformation to each element in the input array.
New in version 3.1.0.
Parameters
col [Column or str] name of column or expression
f [function] a function that is applied to each element of the input array. Can take one of the following forms:
• Unary (x: Column) -> Column: ...
• Binary (x: Column, i: Column) -> Column..., where the second argument is a 0-based index of the element.
and can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).
Returns
pyspark.sql.Column
Examples
>>> df = spark.createDataFrame([(1, [1, 2, 3, 4])], ("key", "values"))
>>> df.select(transform("values", lambda x: x * 2).alias("doubled")).show()
+------------+
|     doubled|
+------------+
|[2, 4, 6, 8]|
+------------+
>>> def alternate(x, i):
...     return when(i % 2 == 0, x).otherwise(-x)
>>> df.select(transform("values", alternate).alias("alternated")).show()
+--------------+
|    alternated|
+--------------+
|[1, -2, 3, -4]|
+--------------+
pyspark.sql.functions.transform_keys
pyspark.sql.functions.transform_keys(col, f)
Applies a function to every key-value pair in a map and returns a map with the results of those applications as the new keys for the pairs.
New in version 3.1.0.
Parameters
col [Column or str] name of column or expression
f [function] a binary function (k: Column, v: Column) -> Column... Can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).
Returns
pyspark.sql.Column
Examples
>>> df = spark.createDataFrame([(1, {"foo": -2.0, "bar": 2.0})], ("id", "data"))>>> df.select(transform_keys(... "data", lambda k, _: upper(k)).alias("data_upper")... ).show(truncate=False)+-------------------------+|data_upper |+-------------------------+|{BAR -> 2.0, FOO -> -2.0}|+-------------------------+
pyspark.sql.functions.transform_values
pyspark.sql.functions.transform_values(col, f)
Applies a function to every key-value pair in a map and returns a map with the results of those applications as the new values for the pairs.
New in version 3.1.0.
Parameters
col [Column or str] name of column or expression
f [function] a binary function (k: Column, v: Column) -> Column... Can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).
Returns
pyspark.sql.Column
Examples
>>> df = spark.createDataFrame([(1, {"IT": 10.0, "SALES": 2.0, "OPS": 24.0})], (→˓"id", "data"))>>> df.select(transform_values(... "data", lambda k, v: when(k.isin("IT", "OPS"), v + 10.0).otherwise(v)... ).alias("new_data")).show(truncate=False)+---------------------------------------+|new_data |+---------------------------------------+|{OPS -> 34.0, IT -> 20.0, SALES -> 2.0}|+---------------------------------------+
pyspark.sql.functions.translate
pyspark.sql.functions.translate(srcCol, matching, replace)
Translates any character in srcCol that occurs in matching into the character at the same position in replace. The characters in replace correspond positionally to the characters in matching; if replace is shorter than matching, characters in matching with no counterpart in replace are removed.
New in version 1.5.0.
Examples
>>> spark.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123") \
...     .alias('r')).collect()
[Row(r='1a2s3ae')]
pyspark.sql.functions.trim
pyspark.sql.functions.trim(col)
Trim the spaces from both ends for the specified string column.
New in version 1.5.
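A short hedged sketch (toy data is illustrative; spaces on both ends are removed):

>>> df = spark.createDataFrame([('  Spark  ',)], ['s'])  # hypothetical toy data
>>> df.select(trim('s').alias('t')).collect()
[Row(t='Spark')]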
pyspark.sql.functions.trunc
pyspark.sql.functions.trunc(date, format)
Returns date truncated to the unit specified by the format.
New in version 1.5.0.
Parameters
date [Column or str]
format [str] ‘year’, ‘yyyy’, ‘yy’ or ‘month’, ‘mon’, ‘mm’
Examples
>>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
>>> df.select(trunc(df.d, 'year').alias('year')).collect()
[Row(year=datetime.date(1997, 1, 1))]
>>> df.select(trunc(df.d, 'mon').alias('month')).collect()
[Row(month=datetime.date(1997, 2, 1))]
pyspark.sql.functions.udf
pyspark.sql.functions.udf(f=None, returnType=StringType)
Creates a user defined function (UDF).
New in version 1.3.0.
Parameters
f [function] python function if used as a standalone function
returnType [pyspark.sql.types.DataType or str] the return type of the user-defined function. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.
Notes
The user-defined functions are considered deterministic by default. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. If your function is not deterministic, call asNondeterministic on the user defined function. E.g.:
>>> from pyspark.sql.types import IntegerType
>>> import random
>>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
The user-defined functions do not support conditional expressions or short-circuiting in boolean expressions; all branches end up being executed internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions.
The user-defined functions do not take keyword arguments on the calling side.
Examples
>>> from pyspark.sql.types import IntegerType
>>> slen = udf(lambda s: len(s), IntegerType())
>>> @udf
... def to_upper(s):
...     if s is not None:
...         return s.upper()
...
>>> @udf(returnType=IntegerType())
... def add_one(x):
...     if x is not None:
...         return x + 1
...
>>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
>>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")).show()
+----------+--------------+------------+
|slen(name)|to_upper(name)|add_one(age)|
+----------+--------------+------------+
|         8|      JOHN DOE|          22|
+----------+--------------+------------+
pyspark.sql.functions.unbase64
pyspark.sql.functions.unbase64(col)
Decodes a BASE64 encoded string column and returns it as a binary column.
New in version 1.5.
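A short hedged sketch (toy data is illustrative; 'U3Bhcms=' is the Base64 encoding of 'Spark'):

>>> df = spark.createDataFrame([('U3Bhcms=',)], ['a'])  # hypothetical toy data
>>> df.select(unbase64('a').alias('b')).collect()
[Row(b=bytearray(b'Spark'))]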
pyspark.sql.functions.unhex
pyspark.sql.functions.unhex(col)
Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to the byte representation of the number.
New in version 1.5.0.
Examples
>>> spark.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
[Row(unhex(a)=bytearray(b'ABC'))]
pyspark.sql.functions.unix_timestamp
pyspark.sql.functions.unix_timestamp(timestamp=None, format='yyyy-MM-dd HH:mm:ss')
Converts a time string with the given pattern (‘yyyy-MM-dd HH:mm:ss’, by default) to a Unix timestamp (in seconds), using the default timezone and the default locale; returns null on failure.
if timestamp is None, then it returns current timestamp.
New in version 1.5.0.
Examples
>>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")>>> time_df = spark.createDataFrame([('2015-04-08',)], ['dt'])>>> time_df.select(unix_timestamp('dt', 'yyyy-MM-dd').alias('unix_time')).→˓collect()[Row(unix_time=1428476400)]>>> spark.conf.unset("spark.sql.session.timeZone")
pyspark.sql.functions.upper
pyspark.sql.functions.upper(col)
Converts a string expression to upper case.
New in version 1.5.
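A short hedged sketch (toy data is illustrative):

>>> df = spark.createDataFrame([('Spark',)], ['s'])  # hypothetical toy data
>>> df.select(upper('s').alias('u')).collect()
[Row(u='SPARK')]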
pyspark.sql.functions.var_pop
pyspark.sql.functions.var_pop(col)
Aggregate function: returns the population variance of the values in a group.
New in version 1.6.
pyspark.sql.functions.var_samp
pyspark.sql.functions.var_samp(col)
Aggregate function: returns the unbiased sample variance of the values in a group.
New in version 1.6.
pyspark.sql.functions.variance
pyspark.sql.functions.variance(col)
Aggregate function: alias for var_samp.
New in version 1.6.
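None of the variance variants carry an example here. A minimal hedged sketch contrasting the sample and population estimators (toy data is illustrative; variance itself would match var_samp):

>>> df = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ['v'])  # hypothetical toy data
>>> df.select(var_samp('v').alias('vs'), var_pop('v').alias('vp')).collect()
[Row(vs=1.0, vp=0.6666666666666666)]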
pyspark.sql.functions.weekofyear
pyspark.sql.functions.weekofyear(col)
Extract the week number of a given date as integer.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(weekofyear(df.dt).alias('week')).collect()
[Row(week=15)]
pyspark.sql.functions.when
pyspark.sql.functions.when(condition, value)
Evaluates a list of conditions and returns one of multiple possible result expressions. If Column.otherwise() is not invoked, None is returned for unmatched conditions.
New in version 1.4.0.
Parameters
condition [Column] a boolean Column expression.
value : a literal value, or a Column expression.
Examples

>>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
[Row(age=3), Row(age=4)]
>>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect()
[Row(age=3), Row(age=None)]
pyspark.sql.functions.window
pyspark.sql.functions.window(timeColumn, windowDuration, slideDuration=None, startTime=None)
Bucketize rows into one or more time windows given a timestamp specifying column. Window starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in the order of months are not supported.
The time column must be of pyspark.sql.types.TimestampType.
Durations are provided as strings, e.g. ‘1 second’, ‘1 day 12 hours’, ‘2 minutes’. Valid interval strings are ‘week’, ‘day’, ‘hour’, ‘minute’, ‘second’, ‘millisecond’, ‘microsecond’. If the slideDuration is not provided, the windows will be tumbling windows.
The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start window intervals. For example, in order to have hourly tumbling windows that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15, ... provide startTime as 15 minutes.
The output column will be a struct called ‘window’ by default with the nested columns ‘start’ and ‘end’, where ‘start’ and ‘end’ will be of pyspark.sql.types.TimestampType.
New in version 2.0.0.
Examples
>>> df = spark.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val")>>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum"))>>> w.select(w.window.start.cast("string").alias("start"),... w.window.end.cast("string").alias("end"), "sum").collect()[Row(start='2016-03-11 09:00:05', end='2016-03-11 09:00:10', sum=1)]
pyspark.sql.functions.xxhash64
pyspark.sql.functions.xxhash64(*cols)
Calculates the hash code of given columns using the 64-bit variant of the xxHash algorithm, and returns the result as a long column.
New in version 3.0.0.
Examples
>>> spark.createDataFrame([('ABC',)], ['a']).select(xxhash64('a').alias('hash')).collect()
[Row(hash=4105715581806190027)]
pyspark.sql.functions.year
pyspark.sql.functions.year(col)
Extract the year of a given date as integer.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(year('dt').alias('year')).collect()
[Row(year=2015)]
pyspark.sql.functions.years
pyspark.sql.functions.years(col)
Partition transform function: A transform for timestamps and dates to partition data into years.
New in version 3.1.0.
Notes
This function can be used only in combination with the partitionedBy() method of DataFrameWriterV2.
Examples
>>> df.writeTo("catalog.db.table").partitionedBy(... years("ts")... ).createOrReplace()
pyspark.sql.functions.zip_with
pyspark.sql.functions.zip_with(col1, col2, f)
Merge two given arrays, element-wise, into a single array using a function. If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying the function.
New in version 3.1.0.
Parameters
col1 [Column or str] name of the first column or expression
col2 [Column or str] name of the second column or expression
f [function] a binary function (x1: Column, x2: Column) -> Column... Can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).
Returns
pyspark.sql.Column
Examples
>>> df = spark.createDataFrame([(1, [1, 3, 5, 8], [0, 2, 4, 6])], ("id", "xs", "ys"))
>>> df.select(zip_with("xs", "ys", lambda x, y: x ** y).alias("powers")).show(truncate=False)
+---------------------------+
|powers                     |
+---------------------------+
|[1.0, 9.0, 625.0, 262144.0]|
+---------------------------+
>>> df = spark.createDataFrame([(1, ["foo", "bar"], [1, 2, 3])], ("id", "xs", "ys→˓"))>>> df.select(zip_with("xs", "ys", lambda x, y: concat_ws("_", x, y)).alias("xs_ys→˓")).show()+-----------------+| xs_ys|+-----------------+|[foo_1, bar_2, 3]|+-----------------+
from_avro(data, jsonFormatSchema[, options])    Converts a binary column of Avro format into its corresponding catalyst value.
to_avro(data[, jsonFormatSchema]) Converts a column into binary of avro format.
pyspark.sql.avro.functions.from_avro
pyspark.sql.avro.functions.from_avro(data, jsonFormatSchema, options=None)
Converts a binary column of Avro format into its corresponding catalyst value. The specified schema must match the read data, otherwise the behavior is undefined: it may fail or return an arbitrary result. To deserialize the data with a compatible and evolved schema, the expected Avro schema can be set via the option avroSchema.
New in version 3.0.0.
Parameters
data [Column or str] the binary column.
jsonFormatSchema [str] the avro schema in JSON string format.
options [dict, optional] options to control how the Avro record is parsed.
Notes
Avro is a built-in but external data source module since Spark 2.4. Please deploy the application as per the deployment section of “Apache Avro Data Source Guide”.
Examples
>>> from pyspark.sql import Row
>>> from pyspark.sql.avro.functions import from_avro, to_avro
>>> data = [(1, Row(age=2, name='Alice'))]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> avroDf = df.select(to_avro(df.value).alias("avro"))
>>> avroDf.collect()
[Row(avro=bytearray(b'\x00\x00\x04\x00\nAlice'))]
>>> jsonFormatSchema = '''{"type":"record","name":"topLevelRecord","fields":... [{"name":"avro","type":[{"type":"record","name":"value","namespace":→˓"topLevelRecord",... "fields":[{"name":"age","type":["long","null"]},... {"name":"name","type":["string","null"]}]},"null"]}]}'''>>> avroDf.select(from_avro(avroDf.avro, jsonFormatSchema).alias("value")).→˓collect()[Row(value=Row(avro=Row(age=2, name='Alice')))]
pyspark.sql.avro.functions.to_avro
pyspark.sql.avro.functions.to_avro(data, jsonFormatSchema='')
Converts a column into binary of avro format.
New in version 3.0.0.
Parameters
data [Column or str] the data column.
jsonFormatSchema [str, optional] user-specified output avro schema in JSON string format.
Notes
Avro is a built-in but external data source module since Spark 2.4. Please deploy the application as per the deployment section of “Apache Avro Data Source Guide”.
Examples
>>> from pyspark.sql import Row
>>> from pyspark.sql.avro.functions import to_avro
>>> data = ['SPADES']
>>> df = spark.createDataFrame(data, "string")
>>> df.select(to_avro(df.value).alias("suite")).collect()
[Row(suite=bytearray(b'\x00\x0cSPADES'))]
>>> jsonFormatSchema = '''["null", {"type": "enum", "name": "value",... "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]'''>>> df.select(to_avro(df.value, jsonFormatSchema).alias("suite")).collect()[Row(suite=bytearray(b'\x02\x00'))]
3.1.9 Window
Window.currentRow
Window.orderBy(*cols)                  Creates a WindowSpec with the ordering defined.
Window.partitionBy(*cols)              Creates a WindowSpec with the partitioning defined.
Window.rangeBetween(start, end)        Creates a WindowSpec with the frame boundaries defined, from start (inclusive) to end (inclusive).
Window.rowsBetween(start, end)         Creates a WindowSpec with the frame boundaries defined, from start (inclusive) to end (inclusive).
Window.unboundedFollowing
Window.unboundedPreceding
WindowSpec.orderBy(*cols)              Defines the ordering columns in a WindowSpec.
WindowSpec.partitionBy(*cols)          Defines the partitioning columns in a WindowSpec.
WindowSpec.rangeBetween(start, end)    Defines the frame boundaries, from start (inclusive) to end (inclusive).
WindowSpec.rowsBetween(start, end)     Defines the frame boundaries, from start (inclusive) to end (inclusive).
pyspark.sql.Window.currentRow
Window.currentRow = 0
pyspark.sql.Window.orderBy
static Window.orderBy(*cols)
Creates a WindowSpec with the ordering defined.
New in version 1.4.
pyspark.sql.Window.partitionBy
static Window.partitionBy(*cols)
Creates a WindowSpec with the partitioning defined.
New in version 1.4.
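Neither orderBy nor partitionBy carries an example here. A minimal hedged sketch combining the two (toy data and output are illustrative, not from the original reference):

>>> from pyspark.sql import Window
>>> from pyspark.sql import functions as F
>>> df = spark.createDataFrame([(1, "a"), (2, "a"), (3, "b")], ["id", "category"])
>>> w = Window.partitionBy("category").orderBy("id")
>>> df.withColumn("rn", F.row_number().over(w)).sort("category", "id").show()
+---+--------+---+
| id|category| rn|
+---+--------+---+
|  1|       a|  1|
|  2|       a|  2|
|  3|       b|  1|
+---+--------+---+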
pyspark.sql.Window.rangeBetween
static Window.rangeBetween(start, end)
Creates a WindowSpec with the frame boundaries defined, from start (inclusive) to end (inclusive).
Both start and end are relative from the current row. For example, “0” means “current row”, while “-1” meansone off before the current row, and “5” means the five off after the current row.
We recommend users use Window.unboundedPreceding, Window.unboundedFollowing, andWindow.currentRow to specify special boundary values, rather than using integral values directly.
A range-based boundary is based on the actual value of the ORDER BY expression(s). An offset is used to alter the value of the ORDER BY expression, for instance if the current ORDER BY expression has a value of 10 and the lower bound offset is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a number of constraints on the ORDER BY expressions: there can be only one expression and this expression must have a numerical data type. An exception can be made when the offset is unbounded, because no value modification is needed; in this case multiple and non-numeric ORDER BY expressions are allowed.
New in version 2.1.0.
Parameters
start [int] boundary start, inclusive. The frame is unbounded if this is Window.unboundedPreceding, or any value less than or equal to max(-sys.maxsize, -9223372036854775808).
end [int] boundary end, inclusive. The frame is unbounded if this is Window.unboundedFollowing, or any value greater than or equal to min(sys.maxsize,9223372036854775807).
Examples
>>> from pyspark.sql import Window
>>> from pyspark.sql import functions as func
>>> from pyspark.sql import SQLContext
>>> sc = SparkContext.getOrCreate()
>>> sqlContext = SQLContext(sc)
>>> tup = [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")]
>>> df = sqlContext.createDataFrame(tup, ["id", "category"])
>>> window = Window.partitionBy("category").orderBy("id").rangeBetween(Window.currentRow, 1)
>>> df.withColumn("sum", func.sum("id").over(window)).sort("id", "category").show()
+---+--------+---+
| id|category|sum|
+---+--------+---+
|  1|       a|  4|
|  1|       a|  4|
|  1|       b|  3|
|  2|       a|  2|
|  2|       b|  5|
|  3|       b|  3|
+---+--------+---+
pyspark.sql.Window.rowsBetween
static Window.rowsBetween(start, end)
Creates a WindowSpec with the frame boundaries defined, from start (inclusive) to end (inclusive).
Both start and end are relative positions from the current row. For example, “0” means “current row”, while“-1” means the row before the current row, and “5” means the fifth row after the current row.
We recommend users use Window.unboundedPreceding, Window.unboundedFollowing, andWindow.currentRow to specify special boundary values, rather than using integral values directly.
A row-based boundary is based on the position of the row within the partition. An offset indicates the number of rows above or below the current row at which the frame for the current row starts or ends. For instance, given a row-based sliding frame with a lower bound offset of -1 and an upper bound offset of +2, the frame for the row with index 5 would range from index 4 to index 7.
New in version 2.1.0.
Parameters
start [int] boundary start, inclusive. The frame is unbounded if this is Window.unboundedPreceding, or any value less than or equal to -9223372036854775808.
end [int] boundary end, inclusive. The frame is unbounded if this is Window.unboundedFollowing, or any value greater than or equal to 9223372036854775807.
Examples
>>> from pyspark.sql import Window
>>> from pyspark.sql import functions as func
>>> from pyspark.sql import SQLContext
>>> sc = SparkContext.getOrCreate()
>>> sqlContext = SQLContext(sc)
>>> tup = [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")]
>>> df = sqlContext.createDataFrame(tup, ["id", "category"])
>>> window = Window.partitionBy("category").orderBy("id").rowsBetween(Window.currentRow, 1)
>>> df.withColumn("sum", func.sum("id").over(window)).sort("id", "category", "sum").show()
+---+--------+---+
| id|category|sum|
+---+--------+---+
|  1|       a|  2|
|  1|       a|  3|
|  1|       b|  3|
|  2|       a|  2|
|  2|       b|  5|
|  3|       b|  3|
+---+--------+---+
pyspark.sql.Window.unboundedFollowing
Window.unboundedFollowing = 9223372036854775807
pyspark.sql.Window.unboundedPreceding
Window.unboundedPreceding = -9223372036854775808
pyspark.sql.WindowSpec.orderBy
WindowSpec.orderBy(*cols)
Defines the ordering columns in a WindowSpec.
New in version 1.4.0.
Parameters
cols [str, Column or list] names of columns or expressions
pyspark.sql.WindowSpec.partitionBy
WindowSpec.partitionBy(*cols)
Defines the partitioning columns in a WindowSpec.
New in version 1.4.0.
Parameters
cols [str, Column or list] names of columns or expressions
pyspark.sql.WindowSpec.rangeBetween
WindowSpec.rangeBetween(start, end)
Defines the frame boundaries, from start (inclusive) to end (inclusive).
Both start and end are relative from the current row. For example, “0” means “current row”, while “-1” meansone off before the current row, and “5” means the five off after the current row.
We recommend users use Window.unboundedPreceding, Window.unboundedFollowing, andWindow.currentRow to specify special boundary values, rather than using integral values directly.
New in version 1.4.0.
Parameters
start [int] boundary start, inclusive. The frame is unbounded if this is Window.unboundedPreceding, or any value less than or equal to max(-sys.maxsize, -9223372036854775808).
end [int] boundary end, inclusive. The frame is unbounded if this is Window.unboundedFollowing, or any value greater than or equal to min(sys.maxsize,9223372036854775807).
pyspark.sql.WindowSpec.rowsBetween
WindowSpec.rowsBetween(start, end)
Defines the frame boundaries, from start (inclusive) to end (inclusive).
Both start and end are relative positions from the current row. For example, “0” means “current row”, while“-1” means the row before the current row, and “5” means the fifth row after the current row.
We recommend users use Window.unboundedPreceding, Window.unboundedFollowing, andWindow.currentRow to specify special boundary values, rather than using integral values directly.
New in version 1.4.0.
Parameters
start [int] boundary start, inclusive. The frame is unbounded if this is Window.unboundedPreceding, or any value less than or equal to max(-sys.maxsize, -9223372036854775808).
end [int] boundary end, inclusive. The frame is unbounded if this is Window.unboundedFollowing, or any value greater than or equal to min(sys.maxsize,9223372036854775807).
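The WindowSpec methods ship without examples here. A minimal hedged sketch tying them together in a running sum (toy data and output are illustrative, not from the original reference):

>>> from pyspark.sql import Window
>>> from pyspark.sql import functions as F
>>> df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ["id", "v"])
>>> w = Window.partitionBy("id").orderBy("v").rowsBetween(
...     Window.unboundedPreceding, Window.currentRow)
>>> df.withColumn("running_sum", F.sum("v").over(w)).sort("id", "v").show()
+---+---+-----------+
| id|  v|running_sum|
+---+---+-----------+
|  1|1.0|        1.0|
|  1|2.0|        3.0|
|  2|3.0|        3.0|
+---+---+-----------+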
3.1.10 Grouping
GroupedData.agg(*exprs)                          Compute aggregates and returns the result as a DataFrame.
GroupedData.apply(udf)                           It is an alias of pyspark.sql.GroupedData.applyInPandas(); however, it takes a pyspark.sql.functions.pandas_udf() whereas pyspark.sql.GroupedData.applyInPandas() takes a Python native function.
GroupedData.applyInPandas(func, schema)          Maps each group of the current DataFrame using a pandas udf and returns the result as a DataFrame.
GroupedData.avg(*cols)                           Computes average values for each numeric columns for each group.
GroupedData.cogroup(other)                       Cogroups this group with another group so that we can run cogrouped operations.
GroupedData.count()                              Counts the number of records for each group.
GroupedData.max(*cols)                           Computes the max value for each numeric columns for each group.
GroupedData.mean(*cols)                          Computes average values for each numeric columns for each group.
GroupedData.min(*cols)                           Computes the min value for each numeric column for each group.
GroupedData.pivot(pivot_col[, values])           Pivots a column of the current DataFrame and perform the specified aggregation.
GroupedData.sum(*cols)                           Compute the sum for each numeric columns for each group.
PandasCogroupedOps.applyInPandas(func, schema)   Applies a function to each cogroup using pandas and returns the result as a DataFrame.
pyspark.sql.GroupedData.agg
GroupedData.agg(*exprs)
Compute aggregates and returns the result as a DataFrame.
The available aggregate functions can be:
1. built-in aggregation functions, such as avg, max, min, sum, count
2. group aggregate pandas UDFs, created with pyspark.sql.functions.pandas_udf()
Note: There is no partial aggregation with group aggregate UDFs, i.e., a full shuffle is required. Also, all the data of a group will be loaded into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory.
See also:
pyspark.sql.functions.pandas_udf()
If exprs is a single dict mapping from string to string, then the key is the column to perform aggregation on,and the value is the aggregate function.
Alternatively, exprs can also be a list of aggregate Column expressions.
New in version 1.3.0.
Parameters
exprs [dict] a dict mapping from column name (string) to aggregate functions (string), or a list of Column.
Notes
Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed in a single call to this func-tion.
Examples
>>> gdf = df.groupBy(df.name)
>>> sorted(gdf.agg({"*": "count"}).collect())
[Row(name='Alice', count(1)=1), Row(name='Bob', count(1)=1)]

>>> from pyspark.sql import functions as F
>>> sorted(gdf.agg(F.min(df.age)).collect())
[Row(name='Alice', min(age)=2), Row(name='Bob', min(age)=5)]

>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> @pandas_udf('int', PandasUDFType.GROUPED_AGG)
... def min_udf(v):
...     return v.min()
>>> sorted(gdf.agg(min_udf(df.age)).collect())
[Row(name='Alice', min_udf(age)=2), Row(name='Bob', min_udf(age)=5)]
pyspark.sql.GroupedData.apply
GroupedData.apply(udf)
It is an alias of pyspark.sql.GroupedData.applyInPandas(); however, it takes a pyspark.sql.functions.pandas_udf() whereas pyspark.sql.GroupedData.applyInPandas() takes a Python native function.
New in version 2.3.0.
Parameters
udf [pyspark.sql.functions.pandas_udf()] a grouped map user-defined function returned by pyspark.sql.functions.pandas_udf().
See also:
pyspark.sql.functions.pandas_udf
Notes
It is preferred to use pyspark.sql.GroupedData.applyInPandas() over this API. This API will be deprecated in the future releases.
Examples
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
... def normalize(pdf):
...     v = pdf.v
...     return pdf.assign(v=(v - v.mean()) / v.std())
>>> df.groupby("id").apply(normalize).show()
+---+-------------------+
| id|                  v|
+---+-------------------+
|  1|-0.7071067811865475|
|  1| 0.7071067811865475|
|  2|-0.8320502943378437|
|  2|-0.2773500981126146|
|  2| 1.1094003924504583|
+---+-------------------+
pyspark.sql.GroupedData.applyInPandas
GroupedData.applyInPandas(func, schema)
Maps each group of the current DataFrame using a pandas udf and returns the result as a DataFrame.
The function should take a pandas.DataFrame and return another pandas.DataFrame. For each group, all columns are passed together as a pandas.DataFrame to the user-function, and the returned pandas.DataFrames are combined as a DataFrame.
The schema should be a StructType describing the schema of the returned pandas.DataFrame. The column labels of the returned pandas.DataFrame must either match the field names in the defined schema if specified as strings, or match the field data types by position if not strings, e.g. integer indices. The length of the returned pandas.DataFrame can be arbitrary.
New in version 3.0.0.
Parameters
func [function] a Python native function that takes a pandas.DataFrame, and outputs a pandas.DataFrame.
schema [pyspark.sql.types.DataType or str] the return type of the func in PySpark. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.
See also:
pyspark.sql.functions.pandas_udf
Notes
This function requires a full shuffle. All the data of a group will be loaded into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory.
If returning a new pandas.DataFrame constructed with a dictionary, it is recommended to explicitly index the columns by name to ensure the positions are correct, or alternatively use an OrderedDict. For example, pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a']) or pd.DataFrame(OrderedDict([('id', ids), ('a', data)])).
This API is experimental.
Examples
>>> import pandas as pd
>>> from pyspark.sql.functions import pandas_udf, ceil
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> def normalize(pdf):
...     v = pdf.v
...     return pdf.assign(v=(v - v.mean()) / v.std())
>>> df.groupby("id").applyInPandas(
...     normalize, schema="id long, v double").show()
+---+-------------------+
| id|                  v|
+---+-------------------+
|  1|-0.7071067811865475|
|  1| 0.7071067811865475|
|  2|-0.8320502943378437|
|  2|-0.2773500981126146|
|  2| 1.1094003924504583|
+---+-------------------+
Alternatively, the user can pass a function that takes two arguments. In this case, the grouping key(s) will be passed as the first argument and the data will be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy data types, e.g., numpy.int32 and numpy.float64. The data will still be passed in as a pandas.DataFrame containing all columns from the original Spark DataFrame. This is useful when the user does not want to hardcode grouping key(s) in the function.
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> def mean_func(key, pdf):
...     # key is a tuple of one numpy.int64, which is the value
...     # of 'id' for the current group
...     return pd.DataFrame([key + (pdf.v.mean(),)])
>>> df.groupby('id').applyInPandas(
...     mean_func, schema="id long, v double").show()
+---+---+
| id|  v|
+---+---+
|  1|1.5|
|  2|6.0|
+---+---+
>>> def sum_func(key, pdf):
...     # key is a tuple of two numpy.int64s, which is the values
...     # of 'id' and 'ceil(df.v / 2)' for the current group
...     return pd.DataFrame([key + (pdf.v.sum(),)])
>>> df.groupby(df.id, ceil(df.v / 2)).applyInPandas(
...     sum_func, schema="id long, `ceil(v / 2)` long, v double").show()
+---+-----------+----+
| id|ceil(v / 2)|   v|
+---+-----------+----+
|  2|          5|10.0|
|  1|          1| 3.0|
|  2|          3| 5.0|
|  2|          2| 3.0|
+---+-----------+----+
pyspark.sql.GroupedData.avg
GroupedData.avg(*cols)
Computes the average value for each numeric column for each group.
mean() is an alias for avg().
New in version 1.3.0.
Parameters
cols [str] column names. Non-numeric columns are ignored.
Examples
>>> df.groupBy().avg('age').collect()
[Row(avg(age)=3.5)]
>>> df3.groupBy().avg('age', 'height').collect()
[Row(avg(age)=3.5, avg(height)=82.5)]
pyspark.sql.GroupedData.cogroup
GroupedData.cogroup(other)
Cogroups this group with another group so that we can run cogrouped operations.
New in version 3.0.0.
See PandasCogroupedOps for the operations that can be run.
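For illustration, a minimal sketch (assuming an active spark session; the column names here are hypothetical, and the output row order may vary across runs) that cogroups two DataFrames by id and joins each pair of groups with pandas:

>>> import pandas as pd
>>> df1 = spark.createDataFrame(
...     [(20000101, 1, 1.0), (20000101, 2, 2.0)], ("time", "id", "v1"))
>>> df2 = spark.createDataFrame(
...     [(20000101, 1, "x"), (20000101, 2, "y")], ("time", "id", "v2"))
>>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
...     lambda l, r: pd.merge(l, r, on=["time", "id"]),
...     schema="time int, id int, v1 double, v2 string").show()
+--------+---+---+---+
|    time| id| v1| v2|
+--------+---+---+---+
|20000101|  1|1.0|  x|
|20000101|  2|2.0|  y|
+--------+---+---+---+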
pyspark.sql.GroupedData.count
GroupedData.count()
Counts the number of records for each group.
New in version 1.3.0.
Examples
>>> sorted(df.groupBy(df.age).count().collect())
[Row(age=2, count=1), Row(age=5, count=1)]
pyspark.sql.GroupedData.max
GroupedData.max(*cols)
Computes the max value for each numeric column for each group.
New in version 1.3.0.
Examples
>>> df.groupBy().max('age').collect()
[Row(max(age)=5)]
>>> df3.groupBy().max('age', 'height').collect()
[Row(max(age)=5, max(height)=85)]
pyspark.sql.GroupedData.mean
GroupedData.mean(*cols)
Computes the average value for each numeric column for each group.
mean() is an alias for avg().
New in version 1.3.0.
Parameters
cols [str] column names. Non-numeric columns are ignored.
Examples
>>> df.groupBy().mean('age').collect()
[Row(avg(age)=3.5)]
>>> df3.groupBy().mean('age', 'height').collect()
[Row(avg(age)=3.5, avg(height)=82.5)]
pyspark.sql.GroupedData.min
GroupedData.min(*cols)
Computes the min value for each numeric column for each group.
New in version 1.3.0.
Parameters
cols [str] column names. Non-numeric columns are ignored.
Examples
>>> df.groupBy().min('age').collect()
[Row(min(age)=2)]
>>> df3.groupBy().min('age', 'height').collect()
[Row(min(age)=2, min(height)=80)]
pyspark.sql.GroupedData.pivot
GroupedData.pivot(pivot_col, values=None)
Pivots a column of the current DataFrame and performs the specified aggregation. There are two versions of the pivot function: one that requires the caller to specify the list of distinct values to pivot on, and one that does not. The latter is more concise but less efficient, because Spark needs to first compute the list of distinct values internally.
New in version 1.6.0.
Parameters
pivot_col [str] Name of the column to pivot.
values [list, optional] List of values that will be translated to columns in the output DataFrame.
Examples
# Compute the sum of earnings for each year by course with each course as a separate column
>>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").→˓collect()[Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000,→˓Java=30000)]
# Or without specifying column values (less efficient)
>>> df4.groupBy("year").pivot("course").sum("earnings").collect()[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000,→˓dotNET=48000)]>>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").→˓collect()[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000,→˓dotNET=48000)]
pyspark.sql.GroupedData.sum
GroupedData.sum(*cols)
Computes the sum for each numeric column for each group.
New in version 1.3.0.
Parameters
cols [str] column names. Non-numeric columns are ignored.
Examples
>>> df.groupBy().sum('age').collect()
[Row(sum(age)=7)]
>>> df3.groupBy().sum('age', 'height').collect()
[Row(sum(age)=7, sum(height)=165)]
pyspark.sql.PandasCogroupedOps.applyInPandas
PandasCogroupedOps.applyInPandas(func, schema)
Applies a function to each cogroup using pandas and returns the result as a DataFrame.
The function should take two pandas.DataFrames and return another pandas.DataFrame. For each side of the cogroup, all columns are passed together as a pandas.DataFrame to the user-function, and the returned pandas.DataFrames are combined as a DataFrame.
The schema should be a StructType describing the schema of the returned pandas.DataFrame. The column labels of the returned pandas.DataFrame must either match the field names in the defined schema if specified as strings, or match the field data types by position if not strings, e.g. integer indices. The length of the returned pandas.DataFrame can be arbitrary.
New in version 3.0.0.
Parameters
func [function] a Python native function that takes two pandas.DataFrames, and outputs a pandas.DataFrame, or that takes one tuple (grouping keys) and two pandas DataFrames, and outputs a pandas DataFrame.
schema [pyspark.sql.types.DataType or str] the return type of the func in PySpark. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.
See also:
pyspark.sql.functions.pandas_udf
Notes
This function requires a full shuffle. All the data of a cogroup will be loaded into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory.
If returning a new pandas.DataFrame constructed with a dictionary, it is recommended to explicitly index the columns by name to ensure the positions are correct, or alternatively use an OrderedDict. For example, pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a']) or pd.DataFrame(OrderedDict([('id', ids), ('a', data)])).
This API is experimental.
Examples
>>> import pandas as pd
>>> from pyspark.sql.functions import pandas_udf
>>> df1 = spark.createDataFrame(
...     [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
...     ("time", "id", "v1"))
>>> df2 = spark.createDataFrame(
...     [(20000101, 1, "x"), (20000101, 2, "y")],
...     ("time", "id", "v2"))
>>> def asof_join(l, r):
...     return pd.merge_asof(l, r, on="time", by="id")
>>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
...     asof_join, schema="time int, id int, v1 double, v2 string"
... ).show()
+--------+---+---+---+
|    time| id| v1| v2|
+--------+---+---+---+
|20000101|  1|1.0|  x|
|20000102|  1|3.0|  x|
|20000101|  2|2.0|  y|
|20000102|  2|4.0|  y|
+--------+---+---+---+
Alternatively, the user can define a function that takes three arguments. In this case, the grouping key(s) will be passed as the first argument and the data will be passed as the second and third arguments. The grouping key(s) will be passed as a tuple of numpy data types, e.g., numpy.int32 and numpy.float64. The data will still be passed in as two pandas.DataFrames containing all columns from the original Spark DataFrames.
>>> def asof_join(k, l, r):
...     if k == (1,):
...         return pd.merge_asof(l, r, on="time", by="id")
...     else:
...         return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
>>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
...     asof_join, "time int, id int, v1 double, v2 string").show()
+--------+---+---+---+
|    time| id| v1| v2|
+--------+---+---+---+
|20000101|  1|1.0|  x|
|20000102|  1|3.0|  x|
+--------+---+---+---+
3.2 Structured Streaming
3.2.1 Core Classes
DataStreamReader(spark)  Interface used to load a streaming DataFrame from external storage systems (e.g. file systems, key-value stores).
DataStreamWriter(df)  Interface used to write a streaming DataFrame to external storage systems (e.g. file systems, key-value stores).
ForeachBatchFunction(sql_ctx, func)  This is the Python implementation of the Java interface 'ForeachBatchFunction'.
StreamingQuery(jsq)  A handle to a query that is executing continuously in the background as new data arrives.
StreamingQueryException(desc, stackTrace[, ...])  Exception that stopped a StreamingQuery.
StreamingQueryManager(jsqm)  A class to manage all the active StreamingQuery instances.
pyspark.sql.streaming.DataStreamReader
class pyspark.sql.streaming.DataStreamReader(spark)
Interface used to load a streaming DataFrame from external storage systems (e.g. file systems, key-value stores, etc). Use SparkSession.readStream to access this.
New in version 2.0.0.
Notes
This API is evolving.
Methods
csv(path[, schema, sep, encoding, quote, ...])  Loads a CSV file stream and returns the result as a DataFrame.
format(source)  Specifies the input data source format.
json(path[, schema, primitivesAsString, ...])  Loads a JSON file stream and returns the results as a DataFrame.
load([path, format, schema])  Loads a data stream from a data source and returns it as a DataFrame.
option(key, value)  Adds an input option for the underlying data source.
options(**options)  Adds input options for the underlying data source.
orc(path[, mergeSchema, pathGlobFilter, ...])  Loads an ORC file stream, returning the result as a DataFrame.
parquet(path[, mergeSchema, pathGlobFilter, ...])  Loads a Parquet file stream, returning the result as a DataFrame.
schema(schema)  Specifies the input schema.
text(path[, wholetext, lineSep, ...])  Loads a text file stream and returns a DataFrame whose schema starts with a string column named "value", followed by partitioned columns if there are any.
pyspark.sql.streaming.DataStreamWriter
class pyspark.sql.streaming.DataStreamWriter(df)
Interface used to write a streaming DataFrame to external storage systems (e.g. file systems, key-value stores, etc). Use DataFrame.writeStream to access this.
New in version 2.0.0.
Notes
This API is evolving.
Methods
foreach(f)  Sets the output of the streaming query to be processed using the provided writer f.
foreachBatch(func)  Sets the output of the streaming query to be processed using the provided function.
format(source)  Specifies the underlying output data source.
option(key, value)  Adds an output option for the underlying data source.
options(**options)  Adds output options for the underlying data source.
outputMode(outputMode)  Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
partitionBy(*cols)  Partitions the output by the given columns on the file system.
queryName(queryName)  Specifies the name of the StreamingQuery that can be started with start().
start([path, format, outputMode, ...])  Streams the contents of the DataFrame to a data source.
trigger(*args, **kwargs)  Set the trigger for the stream query.
pyspark.sql.streaming.ForeachBatchFunction
class pyspark.sql.streaming.ForeachBatchFunction(sql_ctx, func)
This is the Python implementation of the Java interface 'ForeachBatchFunction'. This wraps the user-defined 'foreachBatch' function such that it can be called from the JVM when the query is active.
Methods
call(jdf, batch_id)
pyspark.sql.streaming.StreamingQuery
class pyspark.sql.streaming.StreamingQuery(jsq)
A handle to a query that is executing continuously in the background as new data arrives. All these methods are thread-safe.
New in version 2.0.0.
Notes
This API is evolving.
Methods
awaitTermination([timeout])  Waits for the termination of this query, either by query.stop() or by an exception.
exception()  New in version 2.1.0.
explain([extended])  Prints the (logical and physical) plans to the console for debugging purposes.
processAllAvailable()  Blocks until all available data in the source has been processed and committed to the sink.
stop()  Stop this streaming query.
Attributes
id  Returns the unique id of this query that persists across restarts from checkpoint data.
isActive  Whether this streaming query is currently active or not.
lastProgress  Returns the most recent StreamingQueryProgress update of this streaming query, or None if there were no progress updates.
name  Returns the user-specified name of the query, or null if not specified.
recentProgress  Returns an array of the most recent StreamingQueryProgress updates for this query.
runId  Returns the unique id of this query that does not persist across restarts.
status  Returns the current status of the query.
pyspark.sql.streaming.StreamingQueryException
exception pyspark.sql.streaming.StreamingQueryException(desc, stackTrace, cause=None)
Exception that stopped a StreamingQuery.
pyspark.sql.streaming.StreamingQueryManager
class pyspark.sql.streaming.StreamingQueryManager(jsqm)
A class to manage all the active StreamingQuery instances.
New in version 2.0.0.
Notes
This API is evolving.
Methods
awaitAnyTermination([timeout])  Wait until any of the queries on the associated SQLContext has terminated since the creation of the context, or since resetTerminated() was called.
get(id)  Returns an active query from this SQLContext, or throws an exception if an active query with this name doesn't exist.
resetTerminated()  Forget about past terminated queries so that awaitAnyTermination() can be used again to wait for new terminations.
Attributes
active  Returns a list of active queries associated with this SQLContext.
3.2.2 Input and Output
DataStreamReader.csv(path[, schema, sep, ...])  Loads a CSV file stream and returns the result as a DataFrame.
DataStreamReader.format(source)  Specifies the input data source format.
DataStreamReader.json(path[, schema, ...])  Loads a JSON file stream and returns the results as a DataFrame.
DataStreamReader.load([path, format, schema])  Loads a data stream from a data source and returns it as a DataFrame.
DataStreamReader.option(key, value)  Adds an input option for the underlying data source.
DataStreamReader.options(**options)  Adds input options for the underlying data source.
DataStreamReader.orc(path[, mergeSchema, ...])  Loads an ORC file stream, returning the result as a DataFrame.
DataStreamReader.parquet(path[, ...])  Loads a Parquet file stream, returning the result as a DataFrame.
DataStreamReader.schema(schema)  Specifies the input schema.
DataStreamReader.text(path[, wholetext, ...])  Loads a text file stream and returns a DataFrame whose schema starts with a string column named "value", followed by partitioned columns if there are any.
DataStreamWriter.foreach(f)  Sets the output of the streaming query to be processed using the provided writer f.
DataStreamWriter.foreachBatch(func)  Sets the output of the streaming query to be processed using the provided function.
DataStreamWriter.format(source)  Specifies the underlying output data source.
DataStreamWriter.option(key, value)  Adds an output option for the underlying data source.
DataStreamWriter.options(**options)  Adds output options for the underlying data source.
DataStreamWriter.outputMode(outputMode)  Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
DataStreamWriter.partitionBy(*cols)  Partitions the output by the given columns on the file system.
DataStreamWriter.queryName(queryName)  Specifies the name of the StreamingQuery that can be started with start().
DataStreamWriter.start([path, format, ...])  Streams the contents of the DataFrame to a data source.
DataStreamWriter.trigger(*args, **kwargs)  Set the trigger for the stream query.
pyspark.sql.streaming.DataStreamReader.csv
DataStreamReader.csv(path, schema=None, sep=None, encoding=None, quote=None, escape=None, comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None, pathGlobFilter=None, recursiveFileLookup=None, unescapedQuoteHandling=None)
Loads a CSV file stream and returns the result as a DataFrame.
This function will go through the input once to determine the input schema if inferSchema is enabled. To avoid going through the entire data once, disable the inferSchema option or specify the schema explicitly using schema.
Parameters
path [str or list] string, or list of strings, for input path(s).
schema [pyspark.sql.types.StructType or str, optional] an optional pyspark.sql.types.StructType for the input schema or a DDL-formatted string (For example col0 INT, col1 DOUBLE).
sep [str, optional] sets a separator (one or more characters) for each field and value. If None is set, it uses the default value, ,.
encoding [str, optional] decodes the CSV files by the given encoding type. If None is set, it uses the default value, UTF-8.
quote [str, optional] sets a single character used for escaping quoted values where the separator can be part of the value. If None is set, it uses the default value, ". If you would like to turn off quotations, you need to set an empty string.
escape [str, optional] sets a single character used for escaping quotes inside an already quoted value. If None is set, it uses the default value, \.
comment [str, optional] sets a single character used for skipping lines beginning with this character. By default (None), it is disabled.
header [str or bool, optional] uses the first line as names of columns. If None is set, it uses the default value, false.
inferSchema [str or bool, optional] infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, false.
enforceSchema [str or bool, optional] If it is set to true, the specified or inferred schema will be forcibly applied to datasource files, and headers in CSV files will be ignored. If the option is set to false, the schema will be validated against all headers in CSV files or the first header in RDD if the header option is set to true. Field names in the schema and column names in CSV headers are checked by their positions taking into account spark.sql.caseSensitive. If None is set, true is used by default. Though the default value is true, it is recommended to disable the enforceSchema option to avoid incorrect results.
ignoreLeadingWhiteSpace [str or bool, optional] a flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, false.
ignoreTrailingWhiteSpace [str or bool, optional] a flag indicating whether or not trailing whitespaces from values being read should be skipped. If None is set, it uses the default value, false.
nullValue [str, optional] sets the string representation of a null value. If None is set, it uses the default value, empty string. Since 2.0.1, this nullValue param applies to all supported types including the string type.
nanValue [str, optional] sets the string representation of a non-number value. If None is set, it uses the default value, NaN.
positiveInf [str, optional] sets the string representation of a positive infinity value. If None is set, it uses the default value, Inf.
negativeInf [str, optional] sets the string representation of a negative infinity value. If None is set, it uses the default value, Inf.
dateFormat [str, optional] sets the string that indicates a date format. Custom date formats follow the formats at datetime pattern. This applies to date type. If None is set, it uses the default value, yyyy-MM-dd.
timestampFormat [str, optional] sets the string that indicates a timestamp format. Custom date formats follow the formats at datetime pattern. This applies to timestamp type. If None is set, it uses the default value, yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX].
maxColumns [str or int, optional] defines a hard limit of how many columns a record can have. If None is set, it uses the default value, 20480.
maxCharsPerColumn [str or int, optional] defines the maximum number of characters allowed for any given value being read. If None is set, it uses the default value, -1 meaning unlimited length.
maxMalformedLogPerPartition [str or int, optional] this parameter is no longer used since Spark 2.2.0. If specified, it is ignored.
mode [str, optional] allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, PERMISSIVE.
• PERMISSIVE: when it meets a corrupted record, puts the malformed string into a field configured by columnNameOfCorruptRecord, and sets malformed fields to null. To keep corrupt records, a user can set a string type field named columnNameOfCorruptRecord in a user-defined schema. If a schema does not have the field, it drops corrupt records during parsing. A record with fewer or more tokens than the schema is not a corrupted record to CSV. When it meets a record having fewer tokens than the length of the schema, it sets null to the extra fields. When the record has more tokens than the length of the schema, it drops the extra tokens.
• DROPMALFORMED: ignores the whole corrupted records.
• FAILFAST: throws an exception when it meets corrupted records.
columnNameOfCorruptRecord [str, optional] allows renaming the new field having a malformed string created by PERMISSIVE mode. This overrides spark.sql.columnNameOfCorruptRecord. If None is set, it uses the value specified in spark.sql.columnNameOfCorruptRecord.
multiLine [str or bool, optional] parse one record, which may span multiple lines. If None is set, it uses the default value, false.
charToEscapeQuoteEscaping [str, optional] sets a single character used for escaping the escape for the quote character. If None is set, the default value is the escape character when escape and quote characters are different, \0 otherwise.
emptyValue [str, optional] sets the string representation of an empty value. If None is set, it uses the default value, empty string.
locale [str, optional] sets a locale as a language tag in IETF BCP 47 format. If None is set, it uses the default value, en-US. For instance, locale is used while parsing dates and timestamps.
lineSep [str, optional] defines the line separator that should be used for parsing. If None is set, it covers all \r, \r\n and \n. Maximum length is 1 character.
pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.
recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.
unescapedQuoteHandling [str, optional] defines how the CsvParser will handle values with unescaped quotes. If None is set, it uses the default value, STOP_AT_DELIMITER.
• STOP_AT_CLOSING_QUOTE: If unescaped quotes are found in the input, accumulate the quote character and proceed parsing the value as a quoted value, until a closing quote is found.
• BACK_TO_DELIMITER: If unescaped quotes are found in the input, consider the value as an unquoted value. This will make the parser accumulate all characters of the current parsed value until the delimiter is found. If no delimiter is found in the value, the parser will continue accumulating characters from the input until a delimiter or line ending is found.
• STOP_AT_DELIMITER: If unescaped quotes are found in the input, consider the value as an unquoted value. This will make the parser accumulate all characters until the delimiter or a line ending is found in the input.
• SKIP_VALUE: If unescaped quotes are found in the input, the content parsed for the given value will be skipped and the value set in nullValue will be produced instead.
• RAISE_ERROR: If unescaped quotes are found in the input, a TextParsingException will be thrown.
New in version 2.0.0.
Notes
This API is evolving.
Examples
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema=sdf_schema)
>>> csv_sdf.isStreaming
True
>>> csv_sdf.schema == sdf_schema
True
pyspark.sql.streaming.DataStreamReader.format
DataStreamReader.format(source)Specifies the input data source format.
New in version 2.0.0.
Parameters
source [str] name of the data source, e.g. ‘json’, ‘parquet’.
Notes
This API is evolving.
Examples
>>> s = spark.readStream.format("text")
pyspark.sql.streaming.DataStreamReader.json
DataStreamReader.json(path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None, dropFieldIfAllNull=None, encoding=None, pathGlobFilter=None, recursiveFileLookup=None, allowNonNumericNumbers=None)
Loads a JSON file stream and returns the results as a DataFrame.
JSON Lines (newline-delimited JSON) is supported by default. For JSON (one record per file), set themultiLine parameter to true.
If the schema parameter is not specified, this function goes through the input once to determine the inputschema.
New in version 2.0.0.
Parameters
path [str] string representing the path to the JSON dataset, or an RDD of Strings storing JSON objects.
schema [pyspark.sql.types.StructType or str, optional] an optional pyspark.sql.types.StructType for the input schema or a DDL-formatted string (For example col0 INT, col1 DOUBLE).
primitivesAsString [str or bool, optional] infers all primitive values as a string type. If None is set, it uses the default value, false.
prefersDecimal [str or bool, optional] infers all floating-point values as a decimal type. If the values do not fit in decimal, then it infers them as doubles. If None is set, it uses the default value, false.
allowComments [str or bool, optional] ignores Java/C++ style comments in JSON records. If None is set, it uses the default value, false.
allowUnquotedFieldNames [str or bool, optional] allows unquoted JSON field names. If None is set, it uses the default value, false.
allowSingleQuotes [str or bool, optional] allows single quotes in addition to double quotes. If None is set, it uses the default value, true.
allowNumericLeadingZero [str or bool, optional] allows leading zeros in numbers (e.g. 00012). If None is set, it uses the default value, false.
allowBackslashEscapingAnyCharacter [str or bool, optional] allows accepting quoting of all characters using the backslash quoting mechanism. If None is set, it uses the default value, false.
mode [str, optional] allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, PERMISSIVE.
• PERMISSIVE: when it meets a corrupted record, puts the malformed string into a field configured by columnNameOfCorruptRecord, and sets malformed fields to null. To keep corrupt records, a user can set a string type field named columnNameOfCorruptRecord in a user-defined schema. If a schema does not have the field, it drops corrupt records during parsing. When inferring a schema, it implicitly adds a columnNameOfCorruptRecord field in an output schema.
• DROPMALFORMED: ignores the whole corrupted records.
• FAILFAST: throws an exception when it meets corrupted records.
columnNameOfCorruptRecord [str, optional] allows renaming the new field having a malformed string created by PERMISSIVE mode. This overrides spark.sql.columnNameOfCorruptRecord. If None is set, it uses the value specified in spark.sql.columnNameOfCorruptRecord.
dateFormat [str, optional] sets the string that indicates a date format. Custom date formats follow the formats at datetime pattern. This applies to date type. If None is set, it uses the default value, yyyy-MM-dd.
timestampFormat [str, optional] sets the string that indicates a timestamp format. Custom date formats follow the formats at datetime pattern. This applies to timestamp type. If None is set, it uses the default value, yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX].
multiLine [str or bool, optional] parse one record, which may span multiple lines, per file. If None is set, it uses the default value, false.
allowUnquotedControlChars [str or bool, optional] allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not.
lineSep [str, optional] defines the line separator that should be used for parsing. If None is set, it covers all \r, \r\n and \n.
locale [str, optional] sets a locale as a language tag in IETF BCP 47 format. If None is set, it uses the default value, en-US. For instance, locale is used while parsing dates and timestamps.
dropFieldIfAllNull [str or bool, optional] whether to ignore a column of all null values or an empty array/struct during schema inference. If None is set, it uses the default value, false.
encoding [str or bool, optional] allows forcibly setting one of the standard basic or extended encodings for the JSON files. For example UTF-16BE, UTF-32LE. If None is set, the encoding of input JSON will be detected automatically when the multiLine option is set to true.
pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.
recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.
allowNonNumericNumbers [str or bool, optional] allows the JSON parser to recognize a set of "Not-a-Number" (NaN) tokens as legal floating number values. If None is set, it uses the default value, true.
• +INF: for positive infinity, as well as alias of +Infinity and Infinity.
• -INF: for negative infinity, alias -Infinity.
• NaN: for other not-a-numbers, like result of division by zero.
Notes
This API is evolving.
Examples
>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema=sdf_schema)
>>> json_sdf.isStreaming
True
>>> json_sdf.schema == sdf_schema
True
pyspark.sql.streaming.DataStreamReader.load
DataStreamReader.load(path=None, format=None, schema=None, **options)
Loads a data stream from a data source and returns it as a DataFrame.
New in version 2.0.0.
Parameters
path [str, optional] optional string for file-system backed data sources.
format [str, optional] optional string for format of the data source. Defaults to 'parquet'.
schema [pyspark.sql.types.StructType or str, optional] optional pyspark.sql.types.StructType for the input schema or a DDL-formatted string (For example col0 INT, col1 DOUBLE).
**options [dict] all other string options
Notes
This API is evolving.
Examples
>>> json_sdf = spark.readStream.format("json") \
...     .schema(sdf_schema) \
...     .load(tempfile.mkdtemp())
>>> json_sdf.isStreaming
True
>>> json_sdf.schema == sdf_schema
True
pyspark.sql.streaming.DataStreamReader.option
DataStreamReader.option(key, value)
Adds an input option for the underlying data source.
You can set the following option(s) for reading files:
• timeZone: sets the string that indicates a time zone ID to be used to parse timestamps in theJSON/CSV datasources or partition values. The following formats of timeZone are supported:
– Region-based zone ID: It should have the form ‘area/city’, such as ‘America/Los_Angeles’.
– Zone offset: It should be in the format ‘(+|-)HH:mm’, for example ‘-08:00’ or ‘+01:00’. Also‘UTC’ and ‘Z’ are supported as aliases of ‘+00:00’.
Other short names like 'CST' are not recommended because they can be ambiguous. If it isn't set, the current value of the SQL config spark.sql.session.timeZone is used by default.
New in version 2.0.0.
Notes
This API is evolving.
Examples
>>> s = spark.readStream.option("x", 1)
pyspark.sql.streaming.DataStreamReader.options
DataStreamReader.options(**options)
Adds input options for the underlying data source.
You can set the following option(s) for reading files:
• timeZone: sets the string that indicates a time zone ID to be used to parse timestamps in theJSON/CSV datasources or partition values. The following formats of timeZone are supported:
– Region-based zone ID: It should have the form ‘area/city’, such as ‘America/Los_Angeles’.
– Zone offset: It should be in the format ‘(+|-)HH:mm’, for example ‘-08:00’ or ‘+01:00’. Also‘UTC’ and ‘Z’ are supported as aliases of ‘+00:00’.
Other short names like 'CST' are not recommended because they can be ambiguous. If it isn't set, the current value of the SQL config spark.sql.session.timeZone is used by default.
New in version 2.0.0.
Notes
This API is evolving.
Examples
>>> s = spark.readStream.options(x="1", y=2)
pyspark.sql.streaming.DataStreamReader.orc
DataStreamReader.orc(path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None)
Loads an ORC file stream, returning the result as a DataFrame.
New in version 2.3.0.
Parameters
mergeSchema [str or bool, optional] sets whether we should merge schemas collected from all ORC part-files. This will override spark.sql.orc.mergeSchema. The default value is specified in spark.sql.orc.mergeSchema.
pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.
recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.
Examples
>>> orc_sdf = spark.readStream.schema(sdf_schema).orc(tempfile.mkdtemp())
>>> orc_sdf.isStreaming
True
>>> orc_sdf.schema == sdf_schema
True
pyspark.sql.streaming.DataStreamReader.parquet
DataStreamReader.parquet(path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None)
Loads a Parquet file stream, returning the result as a DataFrame.
New in version 2.0.0.
Parameters
mergeSchema [str or bool, optional] sets whether we should merge schemas collected from all Parquet part-files. This will override spark.sql.parquet.mergeSchema. The default value is specified in spark.sql.parquet.mergeSchema.
pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.
recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.
Examples
>>> parquet_sdf = spark.readStream.schema(sdf_schema).parquet(tempfile.mkdtemp())
>>> parquet_sdf.isStreaming
True
>>> parquet_sdf.schema == sdf_schema
True
pyspark.sql.streaming.DataStreamReader.schema
DataStreamReader.schema(schema)
Specifies the input schema.
Some data sources (e.g. JSON) can infer the input schema automatically from data. By specifying the schema here, the underlying data source can skip the schema inference step, and thus speed up data loading.
New in version 2.0.0.
Parameters
schema [pyspark.sql.types.StructType or str] a pyspark.sql.types.StructType object or a DDL-formatted string (For example col0 INT, col1 DOUBLE).
Notes
This API is evolving.
Examples
>>> s = spark.readStream.schema(sdf_schema)
>>> s = spark.readStream.schema("col0 INT, col1 DOUBLE")
pyspark.sql.streaming.DataStreamReader.text
DataStreamReader.text(path, wholetext=False, lineSep=None, pathGlobFilter=None, recursiveFileLookup=None)
Loads a text file stream and returns a DataFrame whose schema starts with a string column named "value", followed by partitioned columns if there are any. The text files must be encoded as UTF-8.
By default, each line in the text file is a new row in the resulting DataFrame.
New in version 2.0.0.
Parameters
path [str or list] string, or list of strings, for input path(s).
wholetext [str or bool, optional] if true, read each file from input path(s) as a single row.
lineSep [str, optional] defines the line separator that should be used for parsing. If None is set,it covers all \r, \r\n and \n.
pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.
recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.
Notes
This API is evolving.
Examples
>>> text_sdf = spark.readStream.text(tempfile.mkdtemp())
>>> text_sdf.isStreaming
True
>>> "value" in str(text_sdf.schema)
True
pyspark.sql.streaming.DataStreamWriter.foreach
DataStreamWriter.foreach(f)
Sets the output of the streaming query to be processed using the provided writer f. This is often used to write the output of a streaming query to arbitrary storage systems. The processing logic can be specified in two ways.
1. A function that takes a row as input. This is a simple way to express your processing logic. Note that this does not allow you to deduplicate generated data when failures cause reprocessing of some input data. That would require you to specify the processing logic using the second approach below.
2. An object with a process method and optional open and close methods. The object can have the following methods.
• open(partition_id, epoch_id): Optional method that initializes the processing (for example, open a connection, start a transaction, etc). Additionally, you can use the partition_id and epoch_id to deduplicate regenerated data (discussed later).
• process(row): Non-optional method that processes each Row.
• close(error): Optional method that finalizes and cleans up (for example, close connection, commit transaction, etc.) after all rows have been processed.
The object will be used by Spark in the following way.
• A single copy of this object is responsible for all the data generated by a single task in a query. In other words, one instance is responsible for processing one partition of the data generated in a distributed manner.
• This object must be serializable because each task will get a fresh serialized-deserialized copy of the provided object. Hence, it is strongly recommended that any initialization for writing data (e.g. opening a connection or starting a transaction) is done after the open(...) method has been called, which signifies that the task is ready to generate data.
• The lifecycle of the methods is as follows.

  For each partition with partition_id:

      For each batch/epoch of streaming data with epoch_id:

          Method open(partitionId, epochId) is called.

          If open(...) returns true, for each row in the partition and batch/epoch, method process(row) is called.

          Method close(errorOrNull) is called with the error (if any) seen while processing rows.
Important points to note:
• The partitionId and epochId can be used to deduplicate generated data when failures cause reprocessing of some input data. This depends on the execution mode of the query. If the streaming query is being executed in the micro-batch mode, then every partition represented by a unique tuple (partition_id, epoch_id) is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used to deduplicate and/or transactionally commit data and achieve exactly-once guarantees. However, if the streaming query is being executed in the continuous mode, then this guarantee does not hold and therefore should not be used for deduplication.
• The close() method (if it exists) will be called if the open() method exists and returns successfully (irrespective of the return value), except if the Python process crashes in the middle.
New in version 2.4.0.
Notes
This API is evolving.
Examples
>>> # Print every row using a function
>>> def print_row(row):
...     print(row)
...
>>> writer = sdf.writeStream.foreach(print_row)
>>> # Print every row using an object with process() method
>>> class RowPrinter:
...     def open(self, partition_id, epoch_id):
...         print("Opened %d, %d" % (partition_id, epoch_id))
...         return True
...     def process(self, row):
...         print(row)
...     def close(self, error):
...         print("Closed with error: %s" % str(error))
...
>>> writer = sdf.writeStream.foreach(RowPrinter())
pyspark.sql.streaming.DataStreamWriter.foreachBatch
DataStreamWriter.foreachBatch(func)
Sets the output of the streaming query to be processed using the provided function. This is supported only in the micro-batch execution mode (that is, when the trigger is not continuous). The provided function will be called in every micro-batch with (i) the output rows as a DataFrame and (ii) the batch identifier. The batchId can be used to deduplicate and transactionally write the output (that is, the provided Dataset) to external systems. The output DataFrame is guaranteed to be exactly the same for the same batchId (assuming all operations are deterministic in the query).
New in version 2.4.0.
Notes
This API is evolving.
Examples
>>> def func(batch_df, batch_id):
...     batch_df.collect()
...
>>> writer = sdf.writeStream.foreachBatch(func)
pyspark.sql.streaming.DataStreamWriter.format
DataStreamWriter.format(source)
Specifies the underlying output data source.
New in version 2.0.0.
Parameters
source [str] string, name of the data source, which for now can be ‘parquet’.
Notes
This API is evolving.
Examples
>>> writer = sdf.writeStream.format('json')
pyspark.sql.streaming.DataStreamWriter.option
DataStreamWriter.option(key, value)
Adds an output option for the underlying data source.
You can set the following option(s) for writing files:
• timeZone: sets the string that indicates a time zone ID to be used to format timestamps in theJSON/CSV datasources or partition values. The following formats of timeZone are supported:
– Region-based zone ID: It should have the form ‘area/city’, such as ‘America/Los_Angeles’.
– Zone offset: It should be in the format ‘(+|-)HH:mm’, for example ‘-08:00’ or ‘+01:00’. Also‘UTC’ and ‘Z’ are supported as aliases of ‘+00:00’.
Other short names like 'CST' are not recommended because they can be ambiguous. If it isn't set, the current value of the SQL config spark.sql.session.timeZone is used by default.
New in version 2.0.0.
Notes
This API is evolving.
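For example, a minimal sketch (assuming a streaming DataFrame sdf as in the other examples here) that sets the timeZone option described above:

>>> writer = sdf.writeStream.option("timeZone", "America/Los_Angeles")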
pyspark.sql.streaming.DataStreamWriter.options
DataStreamWriter.options(**options)
Adds output options for the underlying data source.
You can set the following option(s) for writing files:
• timeZone: sets the string that indicates a time zone ID to be used to format timestamps in theJSON/CSV datasources or partition values. The following formats of timeZone are supported:
– Region-based zone ID: It should have the form ‘area/city’, such as ‘America/Los_Angeles’.
– Zone offset: It should be in the format ‘(+|-)HH:mm’, for example ‘-08:00’ or ‘+01:00’. Also‘UTC’ and ‘Z’ are supported as aliases of ‘+00:00’.
Other short names like 'CST' are not recommended because they can be ambiguous. If it isn't set, the current value of the SQL config spark.sql.session.timeZone is used by default.
New in version 2.0.0.
Notes
This API is evolving.
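For example, a minimal sketch (assuming a streaming DataFrame sdf; the checkpoint path is hypothetical) that sets several options at once, including the checkpointLocation mentioned under start():

>>> writer = sdf.writeStream.options(
...     timeZone="America/Los_Angeles",
...     checkpointLocation="/tmp/checkpoint")  # hypothetical path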
pyspark.sql.streaming.DataStreamWriter.outputMode
DataStreamWriter.outputMode(outputMode)
Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
New in version 2.0.0.
Options include:
• append: Only the new rows in the streaming DataFrame/Dataset will be written to the sink
• complete: All the rows in the streaming DataFrame/Dataset will be written to the sink every time there are some updates
• update: Only the rows that were updated in the streaming DataFrame/Dataset will be written to the sink every time there are some updates. If the query doesn't contain aggregations, it will be equivalent to append mode.
Notes
This API is evolving.
Examples
>>> writer = sdf.writeStream.outputMode('append')
pyspark.sql.streaming.DataStreamWriter.partitionBy
DataStreamWriter.partitionBy(*cols)
Partitions the output by the given columns on the file system.
If specified, the output is laid out on the file system similar to Hive’s partitioning scheme.
New in version 2.0.0.
Parameters
cols [str or list] name of columns
Notes
This API is evolving.
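For example, a minimal sketch (assuming a streaming DataFrame sdf that has a column named year; the column name is hypothetical):

>>> writer = sdf.writeStream.partitionBy('year')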
pyspark.sql.streaming.DataStreamWriter.queryName
DataStreamWriter.queryName(queryName)
Specifies the name of the StreamingQuery that can be started with start(). This name must be unique among all the currently active queries in the associated SparkSession.
New in version 2.0.0.
Parameters
queryName [str] unique name for the query
Notes
This API is evolving.
Examples
>>> writer = sdf.writeStream.queryName('streaming_query')
pyspark.sql.streaming.DataStreamWriter.start
DataStreamWriter.start(path=None, format=None, outputMode=None, partitionBy=None, queryName=None, **options)
Streams the contents of the DataFrame to a data source.
The data source is specified by the format and a set of options. If format is not specified, the default data source configured by spark.sql.sources.default will be used.
New in version 2.0.0.
Parameters
path [str, optional] the path in a Hadoop supported file system
format [str, optional] the format used to save
outputMode [str, optional] specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
• append: Only the new rows in the streaming DataFrame/Dataset will be written to the sink
• complete: All the rows in the streaming DataFrame/Dataset will be written to the sink every time there are some updates
• update: Only the rows that were updated in the streaming DataFrame/Dataset will be written to the sink every time there are some updates. If the query doesn't contain aggregations, it will be equivalent to append mode.
partitionBy [str or list, optional] names of partitioning columns
queryName [str, optional] unique name for the query
**options [dict] All other string options. You may want to provide a checkpointLocation for most streams; however, it is not required for a memory stream.
Notes
This API is evolving.
Examples
>>> sq = sdf.writeStream.format('memory').queryName('this_query').start()
>>> sq.isActive
True
>>> sq.name
'this_query'
>>> sq.stop()
>>> sq.isActive
False
>>> sq = sdf.writeStream.trigger(processingTime='5 seconds').start(
...     queryName='that_query', outputMode="append", format='memory')
>>> sq.name
'that_query'
>>> sq.isActive
True
>>> sq.stop()
pyspark.sql.streaming.DataStreamWriter.trigger
DataStreamWriter.trigger(*args, **kwargs)
Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to processingTime='0 seconds'.
New in version 2.0.0.
Parameters
processingTime [str, optional] a processing time interval as a string, e.g. '5 seconds', '1 minute'. Set a trigger that runs a microbatch query periodically based on the processing time. Only one trigger can be set.
once [bool, optional] if set to True, set a trigger that processes only one batch of data in a streaming query then terminates the query. Only one trigger can be set.
continuous [str, optional] a time interval as a string, e.g. '5 seconds', '1 minute'. Set a trigger that runs a continuous query with a given checkpoint interval. Only one trigger can be set.
Notes
This API is evolving.
Examples
>>> # trigger the query for execution every 5 seconds
>>> writer = sdf.writeStream.trigger(processingTime='5 seconds')
>>> # trigger the query for just one batch of data
>>> writer = sdf.writeStream.trigger(once=True)
>>> # trigger the query to run continuously with a 5-second checkpoint interval
>>> writer = sdf.writeStream.trigger(continuous='5 seconds')
3.2.3 Query Management
StreamingQuery.awaitTermination([timeout])  Waits for the termination of this query, either by query.stop() or by an exception.
StreamingQuery.exception()  New in version 2.1.0.
StreamingQuery.explain([extended])  Prints the (logical and physical) plans to the console for debugging purposes.
StreamingQuery.id  Returns the unique id of this query that persists across restarts from checkpoint data.
StreamingQuery.isActive  Whether this streaming query is currently active or not.
StreamingQuery.lastProgress  Returns the most recent StreamingQueryProgress update of this streaming query, or None if there were no progress updates.
StreamingQuery.name  Returns the user-specified name of the query, or null if not specified.
StreamingQuery.processAllAvailable()  Blocks until all available data in the source has been processed and committed to the sink.
StreamingQuery.recentProgress  Returns an array of the most recent StreamingQueryProgress updates for this query.
StreamingQuery.runId  Returns the unique id of this query that does not persist across restarts.
StreamingQuery.status  Returns the current status of the query.
StreamingQuery.stop()  Stop this streaming query.
StreamingQueryManager.active  Returns a list of active queries associated with this SQLContext.
StreamingQueryManager.awaitAnyTermination([...])  Wait until any of the queries on the associated SQLContext has terminated since the creation of the context, or since resetTerminated() was called.
StreamingQueryManager.get(id)  Returns an active query from this SQLContext, or throws an exception if an active query with this name doesn't exist.
StreamingQueryManager.resetTerminated()  Forget about past terminated queries so that awaitAnyTermination() can be used again to wait for new terminations.
pyspark.sql.streaming.StreamingQuery.awaitTermination
StreamingQuery.awaitTermination(timeout=None)
Waits for the termination of this query, either by query.stop() or by an exception. If the query has terminated with an exception, then the exception will be thrown. If timeout is set, it returns whether the query has terminated or not within the timeout seconds.
If the query has terminated, then all subsequent calls to this method will either return immediately (if the query was terminated by stop()), or throw the exception immediately (if the query has terminated with exception).
Throws StreamingQueryException if this query has terminated with an exception.
New in version 2.0.
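For example, a minimal sketch (assuming a streaming DataFrame sdf; the query name is hypothetical, and the result is timing-dependent) showing the timeout form, which returns whether the query terminated within the given seconds:

>>> sq = sdf.writeStream.format('memory').queryName('await_query').start()
>>> sq.awaitTermination(5)  # still running after 5 seconds, so returns False
False
>>> sq.stop()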
pyspark.sql.streaming.StreamingQuery.exception
StreamingQuery.exception()
New in version 2.1.0.
Returns
StreamingQueryException the StreamingQueryException if the query was terminatedby an exception, or None.
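For example, a minimal sketch (assuming a streaming DataFrame sdf; the query name is hypothetical) showing that a healthy query has no exception:

>>> sq = sdf.writeStream.format('memory').queryName('exc_query').start()
>>> sq.exception() is None  # the query has not terminated with an exception
True
>>> sq.stop()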
pyspark.sql.streaming.StreamingQuery.explain
StreamingQuery.explain(extended=False)
Prints the (logical and physical) plans to the console for debugging purposes.
New in version 2.1.0.
Parameters
extended [bool, optional] default False. If False, prints only the physical plan.
Examples
>>> sq = sdf.writeStream.format('memory').queryName('query_explain').start()
>>> sq.processAllAvailable()  # Wait a bit to generate the runtime plans.
>>> sq.explain()
== Physical Plan ==
...
>>> sq.explain(True)
== Parsed Logical Plan ==
...
== Analyzed Logical Plan ==
...
== Optimized Logical Plan ==
...
== Physical Plan ==
...
>>> sq.stop()
pyspark.sql.streaming.StreamingQuery.id
property StreamingQuery.id
Returns the unique id of this query that persists across restarts from checkpoint data. That is, this id is generated when a query is started for the first time, and will be the same every time it is restarted from checkpoint data. There can only be one query with the same id active in a Spark cluster. Also see, runId.
New in version 2.0.
pyspark.sql.streaming.StreamingQuery.isActive
property StreamingQuery.isActive
Whether this streaming query is currently active or not.
New in version 2.0.
pyspark.sql.streaming.StreamingQuery.lastProgress
property StreamingQuery.lastProgress
Returns the most recent StreamingQueryProgress update of this streaming query, or None if there were no progress updates.
New in version 2.1.0.
Returns
dict
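For example, a minimal sketch (assuming a streaming DataFrame sdf; the query name is hypothetical) that waits for at least one batch and then inspects the progress dict:

>>> sq = sdf.writeStream.format('memory').queryName('progress_query').start()
>>> sq.processAllAvailable()  # ensure at least one progress update exists
>>> isinstance(sq.lastProgress, dict)
True
>>> sq.stop()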
pyspark.sql.streaming.StreamingQuery.name
property StreamingQuery.name
Returns the user-specified name of the query, or null if not specified. This name can be specified in the org.apache.spark.sql.streaming.DataStreamWriter as dataframe.writeStream.queryName("query").start(). This name, if set, must be unique across all active queries.
New in version 2.0.
pyspark.sql.streaming.StreamingQuery.processAllAvailable
StreamingQuery.processAllAvailable()
Blocks until all available data in the source has been processed and committed to the sink. This method is intended for testing.
New in version 2.0.0.
Notes
In the case of continually arriving data, this method may block forever. Additionally, this method is only guaranteed to block until data that has been synchronously appended to a stream source prior to invocation has been processed (i.e. getOffset must immediately reflect the addition).
pyspark.sql.streaming.StreamingQuery.recentProgress
property StreamingQuery.recentProgress
Returns an array of the most recent StreamingQueryProgress updates for this query. The number of progress updates retained for each stream is configured by the Spark session configuration spark.sql.streaming.numRecentProgressUpdates.
New in version 2.1.
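A brief sketch of inspecting progress, assuming the spark and sdf fixtures used by the other examples in this reference (the query name is illustrative):

>>> sq = sdf.writeStream.format('memory').queryName('progress_demo').start()
>>> sq.processAllAvailable()
>>> isinstance(sq.recentProgress, list)  # a list of progress dicts, possibly empty
True
>>> sq.lastProgress is None or isinstance(sq.lastProgress, dict)
True
>>> sq.stop()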
pyspark.sql.streaming.StreamingQuery.runId
property StreamingQuery.runId
Returns the unique id of this query that does not persist across restarts. That is, every query that is started (or restarted from checkpoint) will have a different runId.
New in version 2.1.
pyspark.sql.streaming.StreamingQuery.status
property StreamingQuery.status
Returns the current status of the query.
New in version 2.1.
pyspark.sql.streaming.StreamingQuery.stop
StreamingQuery.stop()
Stop this streaming query.
New in version 2.0.
pyspark.sql.streaming.StreamingQueryManager.active
property StreamingQueryManager.active
Returns a list of active queries associated with this SQLContext.
New in version 2.0.0.
Examples
>>> sq = sdf.writeStream.format('memory').queryName('this_query').start()
>>> sqm = spark.streams
>>> # get the list of active streaming queries
>>> [q.name for q in sqm.active]
['this_query']
>>> sq.stop()
pyspark.sql.streaming.StreamingQueryManager.awaitAnyTermination
StreamingQueryManager.awaitAnyTermination(timeout=None)
Wait until any of the queries on the associated SQLContext has terminated since the creation of the context, or since resetTerminated() was called. If any query was terminated with an exception, then the exception will be thrown. If timeout is set, it returns whether the query has terminated or not within the timeout seconds.
If a query has terminated, then subsequent calls to awaitAnyTermination() will either return immediately (if the query was terminated by query.stop()), or throw the exception immediately (if the query was terminated with an exception). Use resetTerminated() to clear past terminations and wait for new terminations.
In the case where multiple queries have terminated since resetTerminated() was called, if any query has terminated with an exception, then awaitAnyTermination() will throw any one of the exceptions. To correctly document exceptions across multiple queries, users need to stop all of them after any of them terminates with an exception, and then check query.exception() for each query.
throws StreamingQueryException, if this query has terminated with an exception
New in version 2.0.
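A minimal sketch, assuming the spark and sdf fixtures used by the other examples here (the query name is illustrative):

>>> spark.streams.resetTerminated()
>>> sq = sdf.writeStream.format('memory').queryName('any_term_demo').start()
>>> sq.stop()
>>> spark.streams.awaitAnyTermination(10)  # True: at least one query has terminated
True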
pyspark.sql.streaming.StreamingQueryManager.get
StreamingQueryManager.get(id)
Returns an active query from this SQLContext, or throws an exception if no active query with this id exists.
New in version 2.0.0.
Examples
>>> sq = sdf.writeStream.format('memory').queryName('this_query').start()
>>> sq.name
'this_query'
>>> sq = spark.streams.get(sq.id)
>>> sq.isActive
True
>>> sq = sqlContext.streams.get(sq.id)
>>> sq.isActive
True
>>> sq.stop()
pyspark.sql.streaming.StreamingQueryManager.resetTerminated
StreamingQueryManager.resetTerminated()
Forget about past terminated queries so that awaitAnyTermination() can be used again to wait for new terminations.
New in version 2.0.0.
Examples
>>> spark.streams.resetTerminated()
3.3 ML
3.3.1 ML Pipeline APIs
Transformer() Abstract class for transformers that transform one dataset into another.
UnaryTransformer() Abstract class for transformers that take one input column, apply a transformation, and output the result as a new column.
Estimator() Abstract class for estimators that fit models to data.
Model() Abstract class for models that are fitted by estimators.
Predictor() Estimator for prediction tasks (regression and classification).
PredictionModel() Model for prediction tasks (regression and classification).
Pipeline(*args, **kwargs) A simple pipeline, which acts as an estimator.
PipelineModel(stages) Represents a compiled pipeline with transformers and fitted models.
Transformer
class pyspark.ml.Transformer
Abstract class for transformers that transform one dataset into another.
New in version 1.3.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
set(param, value) Sets a parameter in the embedded param map.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
Attributes
params Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
Params Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
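To make the merge ordering concrete, a small sketch using Binarizer (any Params instance would do; the threshold values are illustrative):

>>> from pyspark.ml.feature import Binarizer
>>> b = Binarizer(threshold=0.5)                # user-supplied value
>>> pm = b.extractParamMap({b.threshold: 0.7})  # extra overrides user-supplied
>>> pm[b.threshold]
0.7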
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.
set(param, value)
Sets a parameter in the embedded param map.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
Attributes Documentation
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
UnaryTransformer
class pyspark.ml.UnaryTransformer
Abstract class for transformers that take one input column, apply a transformation, and output the result as a new column.
New in version 2.3.0.
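A sketch of a concrete subclass (the PlusOne class and its column names are hypothetical, for illustration only; a SparkSession named spark is assumed, as in the other examples):

>>> from pyspark.sql.types import DoubleType
>>> class PlusOne(UnaryTransformer):
...     # Hypothetical transformer: adds 1.0 to each value of the input column.
...     def createTransformFunc(self):
...         return lambda x: x + 1.0
...     def outputDataType(self):
...         return DoubleType()
...     def validateInputType(self, inputType):
...         if inputType != DoubleType():
...             raise TypeError("Expected DoubleType, got %s" % inputType)
>>> df = spark.createDataFrame([(1.0,), (2.0,)], ["value"])
>>> pt = PlusOne().setInputCol("value").setOutputCol("value_plus_one")
>>> pt.transform(df).head().value_plus_one
2.0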
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
createTransformFunc() Creates the transform function using the given param map.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol() Gets the value of inputCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
outputDataType() Returns the data type of the output column.
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setOutputCol(value) Sets the value of outputCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
transformSchema(schema)
validateInputType(inputType) Validates the input type.
Attributes
inputCol
outputCol
params Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
Params Copy of this instance
abstract createTransformFunc()
Creates the transform function using the given param map. The input param map already takes account of the embedded param map. So the param values should be determined solely by the input param map.

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getInputCol()
Gets the value of inputCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

abstract outputDataType()
Returns the data type of the output column.

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.

setOutputCol(value)
Sets the value of outputCol.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
transformSchema(schema)
abstract validateInputType(inputType)
Validates the input type. Throws an exception if it is invalid.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
Estimator
class pyspark.ml.Estimator
Abstract class for estimators that fit models to data.
New in version 1.3.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
set(param, value) Sets a parameter in the embedded param map.
Attributes
params Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
Params Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
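A short sketch of fitting with a list of param maps (LogisticRegression is used purely for illustration; assumes the spark fixture used by the other examples):

>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame(
...     [(Vectors.dense([0.0]), 0.0), (Vectors.dense([1.0]), 1.0)],
...     ["features", "label"])
>>> lr = LogisticRegression()
>>> models = lr.fit(df, params=[{lr.maxIter: 5}, {lr.maxIter: 10}])
>>> len(models)  # one fitted model per param map
2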
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

set(param, value)
Sets a parameter in the embedded param map.
Attributes Documentation
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
Model
class pyspark.ml.Model
Abstract class for models that are fitted by estimators.
New in version 1.4.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
set(param, value) Sets a parameter in the embedded param map.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
Attributes
params Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
Params Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

set(param, value)
Sets a parameter in the embedded param map.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
Attributes Documentation
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
Predictor
class pyspark.ml.Predictor
Estimator for prediction tasks (regression and classification).
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getFeaturesCol() Gets the value of featuresCol or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setLabelCol(value) Sets the value of labelCol.
setPredictionCol(value) Sets the value of predictionCol.
Attributes
featuresCol
labelCol
params Returns all params ordered by name.
predictionCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
Params Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getFeaturesCol()
Gets the value of featuresCol or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

set(param, value)
Sets a parameter in the embedded param map.
setFeaturesCol(value)
Sets the value of featuresCol.
New in version 3.0.0.
setLabelCol(value)
Sets the value of labelCol.
New in version 3.0.0.
setPredictionCol(value)
Sets the value of predictionCol.
New in version 3.0.0.
Attributes Documentation
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
PredictionModel
class pyspark.ml.PredictionModel
Model for prediction tasks (regression and classification).
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFeaturesCol() Gets the value of featuresCol or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
predict(value) Predict label for the given features.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setPredictionCol(value) Sets the value of predictionCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
Attributes
featuresCol
labelCol
numFeatures Returns the number of features the model was trained on.
params Returns all params ordered by name.
predictionCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
Params Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getFeaturesCol()
Gets the value of featuresCol or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.
getPredictionCol()
Gets the value of predictionCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

abstract predict(value)
Predict label for the given features.
New in version 3.0.0.
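A minimal sketch (the training data is illustrative; assumes the spark fixture used by the other examples; note that predict operates on a single feature vector, with no DataFrame round trip):

>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame(
...     [(Vectors.dense([0.0]), 0.0), (Vectors.dense([1.0]), 1.0)],
...     ["features", "label"])
>>> model = LogisticRegression().fit(df)
>>> isinstance(model.predict(Vectors.dense([1.0])), float)  # a plain float label
True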
set(param, value)
Sets a parameter in the embedded param map.
setFeaturesCol(value)
Sets the value of featuresCol.
New in version 3.0.0.
setPredictionCol(value)
Sets the value of predictionCol.
New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
Attributes Documentation
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
Pipeline
class pyspark.ml.Pipeline(*args, **kwargs)
A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each of which is either an Estimator or a Transformer. When Pipeline.fit() is called, the stages are executed in order. If a stage is an Estimator, its Estimator.fit() method will be called on the input dataset to fit a model. Then the model, which is a transformer, will be used to transform the dataset as the input to the next stage. If a stage is a Transformer, its Transformer.transform() method will be called to produce the dataset for the next stage. The fitted model from a Pipeline is a PipelineModel, which consists of fitted models and transformers, corresponding to the pipeline stages. If stages is an empty list, the pipeline acts as an identity transformer.
New in version 1.3.0.
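An end-to-end sketch (the stages and column names are illustrative; assumes the spark fixture used by the other examples):

>>> from pyspark.ml import Pipeline
>>> from pyspark.ml.feature import Tokenizer, HashingTF
>>> from pyspark.ml.classification import LogisticRegression
>>> training = spark.createDataFrame(
...     [(0, "a b c d e spark", 1.0), (1, "b d", 0.0)],
...     ["id", "text", "label"])
>>> tokenizer = Tokenizer(inputCol="text", outputCol="words")
>>> hashingTF = HashingTF(inputCol="words", outputCol="features")
>>> lr = LogisticRegression(maxIter=10)
>>> pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
>>> model = pipeline.fit(training)  # each Estimator stage is fit, yielding a PipelineModel
>>> 'prediction' in model.transform(training).columns
True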
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getStages() Get pipeline stages.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setParams(self, \*[, stages]) Sets params for Pipeline.
setStages(value) Set pipeline stages.
write() Returns an MLWriter instance for this ML instance.
Attributes
params Returns all params ordered by name.
stages
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance.
New in version 1.4.0.
Parameters
extra [dict, optional] extra parameters
Returns
Pipeline new instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getStages()
Get pipeline stages.
New in version 1.3.0.
hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.
New in version 2.0.0.
save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setParams(self, \*, stages=None)
Sets params for Pipeline.
New in version 1.3.0.
setStages(value)
Set pipeline stages.
New in version 1.3.0.
Parameters
value [list] of pyspark.ml.Transformer or pyspark.ml.Estimator
Returns
Pipeline the pipeline instance
write()
Returns an MLWriter instance for this ML instance.
New in version 2.0.0.
Attributes Documentation
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
stages = Param(parent='undefined', name='stages', doc='a list of pipeline stages')
PipelineModel
class pyspark.ml.PipelineModel(stages)
Represents a compiled pipeline with transformers and fitted models.
New in version 1.3.0.
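A short persistence sketch (assumes the fitted model from the Pipeline example above, and the writable temp_path fixture used by other examples in this reference):

>>> from pyspark.ml import PipelineModel
>>> modelPath = temp_path + "/pipeline-model"
>>> model.save(modelPath)
>>> loaded = PipelineModel.load(modelPath)
>>> len(loaded.stages) == len(model.stages)
True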
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
params Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance.
New in version 1.4.0.
Parameters

extra [dict, optional] extra parameters

Returns

PipelineModel new instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

New in version 2.0.0.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
New in version 2.0.0.
Attributes Documentation
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
3.3.2 Parameters
Param(parent, name, doc[, typeConverter]) A param with self-contained documentation.
Params() Components that take parameters.
TypeConverters() Factory methods for common type conversion functions for Param.typeConverter.
Param
class pyspark.ml.param.Param(parent, name, doc, typeConverter=None)
A param with self-contained documentation.
New in version 1.3.0.
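A sketch of how Param objects are typically declared, mirroring the pattern used by PySpark's own shared params (the HasThing mixin and its param are hypothetical, for illustration only):

>>> from pyspark.ml.param import Param, Params, TypeConverters
>>> class HasThing(Params):
...     # Params._dummy() supplies a placeholder parent; each instance gets its own copy.
...     thing = Param(Params._dummy(), "thing", "an illustrative param.",
...                   typeConverter=TypeConverters.toInt)
...     def __init__(self):
...         super(HasThing, self).__init__()
>>> h = HasThing()
>>> h.set(h.thing, 3.0)  # the type converter coerces the value to an int
>>> h.getOrDefault(h.thing)
3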
Params
class pyspark.ml.param.Params
Components that take parameters. This also provides an internal param map to store parameter values attached to the instance.
New in version 1.3.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
set(param, value) Sets a parameter in the embedded param map.
Attributes
params Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
Params Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.
set(param, value)
Sets a parameter in the embedded param map.
Attributes Documentation
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
TypeConverters
class pyspark.ml.param.TypeConverters
Factory methods for common type conversion functions for Param.typeConverter.
New in version 2.0.0.
Methods
identity(value) Dummy converter that just returns value.
toBoolean(value) Convert a value to a boolean, if possible.
toFloat(value) Convert a value to a float, if possible.
toInt(value) Convert a value to an int, if possible.
toList(value) Convert a value to a list, if possible.
toListFloat(value) Convert a value to list of floats, if possible.
toListInt(value) Convert a value to list of ints, if possible.
toListListFloat(value) Convert a value to list of list of floats, if possible.
toListString(value) Convert a value to list of strings, if possible.
toMatrix(value) Convert a value to a MLlib Matrix, if possible.
toString(value) Convert a value to a string, if possible.
toVector(value) Convert a value to a MLlib Vector, if possible.
Methods Documentation
static identity(value)
Dummy converter that just returns value.

static toBoolean(value)
Convert a value to a boolean, if possible.

static toFloat(value)
Convert a value to a float, if possible.

static toInt(value)
Convert a value to an int, if possible.

static toList(value)
Convert a value to a list, if possible.

static toListFloat(value)
Convert a value to list of floats, if possible.

static toListInt(value)
Convert a value to list of ints, if possible.

static toListListFloat(value)
Convert a value to list of list of floats, if possible.

static toListString(value)
Convert a value to list of strings, if possible.

static toMatrix(value)
Convert a value to a MLlib Matrix, if possible.

static toString(value)
Convert a value to a string, if possible.

static toVector(value)
Convert a value to a MLlib Vector, if possible.
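A few illustrative conversions (the input values are chosen arbitrarily):

>>> from pyspark.ml.param import TypeConverters
>>> TypeConverters.toInt(2.0)
2
>>> TypeConverters.toFloat(3)
3.0
>>> TypeConverters.toListFloat([1, 2])
[1.0, 2.0]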
3.3.3 Feature
ANOVASelector(*args, **kwargs) ANOVA F-value Classification selector, which selects continuous features to use for predicting a categorical label.
ANOVASelectorModel([java_model]) Model fitted by ANOVASelector.
Binarizer(*args, **kwargs) Binarize a column of continuous features given a threshold.
BucketedRandomProjectionLSH(*args, **kwargs) LSH class for Euclidean distance metrics.
BucketedRandomProjectionLSHModel([java_model]) Model fitted by BucketedRandomProjectionLSH, where multiple random vectors are stored.
Bucketizer(*args, **kwargs) Maps a column of continuous features to a column of feature buckets.
ChiSqSelector(*args, **kwargs) Chi-Squared feature selection, which selects categorical features to use for predicting a categorical label.
ChiSqSelectorModel([java_model]) Model fitted by ChiSqSelector.
CountVectorizer(*args, **kwargs) Extracts a vocabulary from document collections and generates a CountVectorizerModel.
CountVectorizerModel([java_model]) Model fitted by CountVectorizer.
DCT(*args, **kwargs) A feature transformer that takes the 1D discrete cosine transform of a real vector.
ElementwiseProduct(*args, **kwargs) Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a provided "weight" vector.
FeatureHasher(*args, **kwargs) Feature hashing projects a set of categorical or numerical features into a feature vector of specified dimension (typically substantially smaller than that of the original feature space).
FValueSelector(*args, **kwargs) F Value Regression feature selector, which selects continuous features to use for predicting a continuous label.
FValueSelectorModel([java_model]) Model fitted by FValueSelector.
HashingTF(*args, **kwargs) Maps a sequence of terms to their term frequencies using the hashing trick.
IDF(*args, **kwargs) Compute the Inverse Document Frequency (IDF) given a collection of documents.
IDFModel([java_model]) Model fitted by IDF.
Imputer(*args, **kwargs) Imputation estimator for completing missing values, either using the mean or the median of the columns in which the missing values are located.
ImputerModel([java_model]) Model fitted by Imputer.
IndexToString(*args, **kwargs) A pyspark.ml.base.Transformer that maps a column of indices back to a new column of corresponding string values.
Interaction(*args, **kwargs) Implements the feature interaction transform.
MaxAbsScaler(*args, **kwargs) Rescale each feature individually to range [-1, 1] by dividing through the largest maximum absolute value in each feature.
MaxAbsScalerModel([java_model]) Model fitted by MaxAbsScaler.
MinHashLSH(*args, **kwargs) LSH class for Jaccard distance.
MinHashLSHModel([java_model]) Model produced by MinHashLSH, where multiple hash functions are stored.
MinMaxScaler(*args, **kwargs) Rescale each feature individually to a common range [min, max] linearly using column summary statistics, which is also known as min-max normalization or Rescaling.
MinMaxScalerModel([java_model]) Model fitted by MinMaxScaler.
NGram(*args, **kwargs) A feature transformer that converts the input array of strings into an array of n-grams.
Normalizer(*args, **kwargs) Normalize a vector to have unit norm using the given p-norm.
OneHotEncoder(*args, **kwargs) A one-hot encoder that maps a column of category indices to a column of binary vectors, with at most a single one-value per row that indicates the input category index.
OneHotEncoderModel([java_model]) Model fitted by OneHotEncoder.
PCA(*args, **kwargs) PCA trains a model to project vectors to a lower dimensional space of the top k principal components.
PCAModel([java_model]) Model fitted by PCA.
PolynomialExpansion(*args, **kwargs) Perform feature expansion in a polynomial space.
QuantileDiscretizer(*args, **kwargs) QuantileDiscretizer takes a column with continuous features and outputs a column with binned categorical features.
RobustScaler(*args, **kwargs) RobustScaler removes the median and scales the data according to the quantile range.
RobustScalerModel([java_model]) Model fitted by RobustScaler.
RegexTokenizer(*args, **kwargs) A regex based tokenizer that extracts tokens either by using the provided regex pattern (in Java dialect) to split the text (default) or repeatedly matching the regex (if gaps is false).
RFormula(*args, **kwargs) Implements the transforms required for fitting a dataset against an R model formula.
RFormulaModel([java_model]) Model fitted by RFormula.
SQLTransformer(*args, **kwargs) Implements the transforms which are defined by SQL statement.
StandardScaler(*args, **kwargs) Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set.
StandardScalerModel([java_model]) Model fitted by StandardScaler.
StopWordsRemover(*args, **kwargs) A feature transformer that filters out stop words from input.
StringIndexer(*args, **kwargs) A label indexer that maps a string column of labels to an ML column of label indices.
StringIndexerModel([java_model]) Model fitted by StringIndexer.
Tokenizer(*args, **kwargs) A tokenizer that converts the input string to lowercase and then splits it by white spaces.
VarianceThresholdSelector(*args, **kwargs) Feature selector that removes all low-variance features.
VarianceThresholdSelectorModel([java_model]) Model fitted by VarianceThresholdSelector.
VectorAssembler(*args, **kwargs) A feature transformer that merges multiple columns into a vector column.
VectorIndexer(*args, **kwargs) Class for indexing categorical feature columns in a dataset of Vector.
VectorIndexerModel([java_model]) Model fitted by VectorIndexer.
VectorSizeHint(*args, **kwargs) A feature transformer that adds size information to the metadata of a vector column.
VectorSlicer(*args, **kwargs) This class takes a feature vector and outputs a new feature vector with a subarray of the original features.
Word2Vec(*args, **kwargs) Word2Vec trains a model of Map(String, Vector), i.e.
Word2VecModel([java_model]) Model fitted by Word2Vec.
ANOVASelector
class pyspark.ml.feature.ANOVASelector(*args, **kwargs)
ANOVA F-value Classification selector, which selects continuous features to use for predicting a categorical label. The selector supports different selection methods: numTopFeatures, percentile, fpr, fdr, fwe.
• numTopFeatures chooses a fixed number of top features according to an F-value classification test.

• percentile is similar but chooses a fraction of all features instead of a fixed number.

• fpr chooses all features whose p-values are below a threshold, thus controlling the false positive rate of selection.

• fdr uses the Benjamini-Hochberg procedure to choose all features whose false discovery rate is below a threshold.

• fwe chooses all features whose p-values are below a threshold. The threshold is scaled by 1 / numFeatures, thus controlling the family-wise error rate of selection.
By default, the selection method is numTopFeatures, with the default number of top features set to 50.
New in version 3.1.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame(
...    [(Vectors.dense([1.7, 4.4, 7.6, 5.8, 9.6, 2.3]), 3.0),
...     (Vectors.dense([8.8, 7.3, 5.7, 7.3, 2.2, 4.1]), 2.0),
...     (Vectors.dense([1.2, 9.5, 2.5, 3.1, 8.7, 2.5]), 1.0),
...     (Vectors.dense([3.7, 9.2, 6.1, 4.1, 7.5, 3.8]), 2.0),
...     (Vectors.dense([8.9, 5.2, 7.8, 8.3, 5.2, 3.0]), 4.0),
...     (Vectors.dense([7.9, 8.5, 9.2, 4.0, 9.4, 2.1]), 4.0)],
...    ["features", "label"])
>>> selector = ANOVASelector(numTopFeatures=1, outputCol="selectedFeatures")
>>> model = selector.fit(df)
>>> model.getFeaturesCol()
'features'
>>> model.setFeaturesCol("features")
ANOVASelectorModel...
>>> model.transform(df).head().selectedFeatures
DenseVector([7.6])
>>> model.selectedFeatures
[2]
>>> anovaSelectorPath = temp_path + "/anova-selector"
>>> selector.save(anovaSelectorPath)
>>> loadedSelector = ANOVASelector.load(anovaSelectorPath)
>>> loadedSelector.getNumTopFeatures() == selector.getNumTopFeatures()
True
>>> modelPath = temp_path + "/anova-selector-model"
>>> model.save(modelPath)
>>> loadedModel = ANOVASelectorModel.load(modelPath)
>>> loadedModel.selectedFeatures == model.selectedFeatures
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getFdr() Gets the value of fdr or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFpr() Gets the value of fpr or its default value.
getFwe() Gets the value of fwe or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getNumTopFeatures() Gets the value of numTopFeatures or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
getPercentile() Gets the value of percentile or its default value.
getSelectorType() Gets the value of selectorType or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFdr(value) Sets the value of fdr.
setFeaturesCol(value) Sets the value of featuresCol.
setFpr(value) Sets the value of fpr.
setFwe(value) Sets the value of fwe.
setLabelCol(value) Sets the value of labelCol.
setNumTopFeatures(value) Sets the value of numTopFeatures.
setOutputCol(value) Sets the value of outputCol.
setParams(self, \*[, numTopFeatures, ...]) Sets params for this ANOVASelector.
setPercentile(value) Sets the value of percentile.
setSelectorType(value) Sets the value of selectorType.
write() Returns an MLWriter instance for this ML instance.
Attributes
fdr
featuresCol
fpr
fwe
labelCol
numTopFeatures
outputCol
params  Returns all params ordered by name.
percentile
selectorType
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
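For instance, a minimal sketch of the merge ordering, assuming the selector from the Examples above (fpr=0.05 is its documented default):

>>> selector = ANOVASelector(numTopFeatures=1)   # user-supplied value
>>> pmap = selector.extractParamMap({selector.numTopFeatures: 10})
>>> pmap[selector.numTopFeatures]                # the extra value wins over the user-set 1
10
>>> selector.extractParamMap()[selector.fpr]     # nothing set, so the default is used
0.05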
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
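A minimal sketch of consuming that iterator, assuming the df and selector from the Examples above; because index values may not be sequential, the returned index is used to slot each model into place:

>>> maps = [{selector.numTopFeatures: k} for k in (1, 2, 3)]
>>> models = [None] * len(maps)
>>> for index, model in selector.fitMultiple(df, maps):
...     models[index] = model
>>> [len(m.selectedFeatures) for m in models]
[1, 2, 3]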
getFdr()
Gets the value of fdr or its default value.

New in version 2.2.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFpr()
Gets the value of fpr or its default value.

New in version 2.1.0.

getFwe()
Gets the value of fwe or its default value.

New in version 2.2.0.

getLabelCol()
Gets the value of labelCol or its default value.

getNumTopFeatures()
Gets the value of numTopFeatures or its default value.

New in version 2.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getPercentile()
Gets the value of percentile or its default value.

New in version 2.1.0.

getSelectorType()
Gets the value of selectorType or its default value.

New in version 2.1.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.
isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFdr(value)
Sets the value of fdr. Only applicable when selectorType = "fdr".

New in version 2.2.0.

setFeaturesCol(value)
Sets the value of featuresCol.

setFpr(value)
Sets the value of fpr. Only applicable when selectorType = "fpr".

New in version 2.1.0.

setFwe(value)
Sets the value of fwe. Only applicable when selectorType = "fwe".

New in version 2.2.0.

setLabelCol(value)
Sets the value of labelCol.

setNumTopFeatures(value)
Sets the value of numTopFeatures. Only applicable when selectorType = "numTopFeatures".

New in version 2.0.0.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, \*, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, fdr=0.05, fwe=0.05)
Sets params for this ANOVASelector.

New in version 3.1.0.

setPercentile(value)
Sets the value of percentile. Only applicable when selectorType = "percentile".

New in version 2.1.0.

setSelectorType(value)
Sets the value of selectorType.

New in version 2.1.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
fdr = Param(parent='undefined', name='fdr', doc='The upper bound of the expected false discovery rate.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fpr = Param(parent='undefined', name='fpr', doc='The highest p-value for features to be kept.')
fwe = Param(parent='undefined', name='fwe', doc='The upper bound of the expected family-wise error rate.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
numTopFeatures = Param(parent='undefined', name='numTopFeatures', doc='Number of features that selector will select, ordered by ascending p-value. If the number of features is < numTopFeatures, then this will select all features.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
percentile = Param(parent='undefined', name='percentile', doc='Percentile of features that selector will select, ordered by ascending p-value.')
selectorType = Param(parent='undefined', name='selectorType', doc='The selector type. Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.')
ANOVASelectorModel
class pyspark.ml.feature.ANOVASelectorModel(java_model=None)
Model fitted by ANOVASelector.
New in version 3.1.0.
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFdr()  Gets the value of fdr or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFpr()  Gets the value of fpr or its default value.
getFwe()  Gets the value of fwe or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getNumTopFeatures()  Gets the value of numTopFeatures or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getPercentile()  Gets the value of percentile or its default value.
getSelectorType()  Gets the value of selectorType or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setOutputCol(value)  Sets the value of outputCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
fdr
featuresCol
fpr
fwe
labelCol
numTopFeatures
outputCol
params  Returns all params ordered by name.
percentile
selectedFeatures  List of indices to select (filter).
selectorType
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getFdr()
Gets the value of fdr or its default value.

New in version 2.2.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFpr()
Gets the value of fpr or its default value.

New in version 2.1.0.

getFwe()
Gets the value of fwe or its default value.

New in version 2.2.0.

getLabelCol()
Gets the value of labelCol or its default value.

getNumTopFeatures()
Gets the value of numTopFeatures or its default value.

New in version 2.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getPercentile()
Gets the value of percentile or its default value.

New in version 2.1.0.
getSelectorType()
Gets the value of selectorType or its default value.

New in version 2.1.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setOutputCol(value)
Sets the value of outputCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
fdr = Param(parent='undefined', name='fdr', doc='The upper bound of the expected false discovery rate.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fpr = Param(parent='undefined', name='fpr', doc='The highest p-value for features to be kept.')
fwe = Param(parent='undefined', name='fwe', doc='The upper bound of the expected family-wise error rate.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
numTopFeatures = Param(parent='undefined', name='numTopFeatures', doc='Number of features that selector will select, ordered by ascending p-value. If the number of features is < numTopFeatures, then this will select all features.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
percentile = Param(parent='undefined', name='percentile', doc='Percentile of features that selector will select, ordered by ascending p-value.')
selectedFeatures
List of indices to select (filter).
New in version 2.0.0.
selectorType = Param(parent='undefined', name='selectorType', doc='The selector type. Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.')
Binarizer
class pyspark.ml.feature.Binarizer(*args, **kwargs)
Binarize a column of continuous features given a threshold. Since 3.0.0, Binarizer can map multiple columns at once by setting the inputCols parameter. Note that when both the inputCol and inputCols parameters are set, an Exception will be thrown. The threshold parameter is used for single column usage, and thresholds is for multiple columns.
New in version 1.4.0.
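The boundary rule matters: values strictly greater than the threshold become 1.0, and values equal to or below it become 0.0. A plain-Python sketch of just that rule (not Spark code, only the arithmetic):

def binarize(value, threshold):
    # Mirrors the documented rule: only values strictly above the threshold map to 1.0.
    return 1.0 if value > threshold else 0.0

assert binarize(0.5, 1.0) == 0.0  # below the threshold
assert binarize(1.0, 1.0) == 0.0  # equal to the threshold still maps to 0.0
assert binarize(1.5, 1.0) == 1.0  # strictly above the threshold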
Examples
>>> df = spark.createDataFrame([(0.5,)], ["values"])
>>> binarizer = Binarizer(threshold=1.0, inputCol="values", outputCol="features")
>>> binarizer.setThreshold(1.0)
Binarizer...
>>> binarizer.setInputCol("values")
Binarizer...
>>> binarizer.setOutputCol("features")
Binarizer...
>>> binarizer.transform(df).head().features
0.0
>>> binarizer.setParams(outputCol="freqs").transform(df).head().freqs
0.0
>>> params = {binarizer.threshold: -0.5, binarizer.outputCol: "vector"}
>>> binarizer.transform(df, params).head().vector
1.0
>>> binarizerPath = temp_path + "/binarizer"
>>> binarizer.save(binarizerPath)
>>> loadedBinarizer = Binarizer.load(binarizerPath)
>>> loadedBinarizer.getThreshold() == binarizer.getThreshold()
True
>>> loadedBinarizer.transform(df).take(1) == binarizer.transform(df).take(1)
True
>>> df2 = spark.createDataFrame([(0.5, 0.3)], ["values1", "values2"])
>>> binarizer2 = Binarizer(thresholds=[0.0, 1.0])
>>> binarizer2.setInputCols(["values1", "values2"]).setOutputCols(["output1", "output2"])
Binarizer...
>>> binarizer2.transform(df2).show()
+-------+-------+-------+-------+
|values1|values2|output1|output2|
+-------+-------+-------+-------+
|    0.5|    0.3|    1.0|    0.0|
+-------+-------+-------+-------+
...
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol()  Gets the value of inputCol or its default value.
getInputCols()  Gets the value of inputCols or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getOutputCols()  Gets the value of outputCols or its default value.
getParam(paramName)  Gets a param by its name.
getThreshold()  Gets the value of threshold or its default value.
getThresholds()  Gets the value of thresholds or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setInputCol(value)  Sets the value of inputCol.
setInputCols(value)  Sets the value of inputCols.
setOutputCol(value)  Sets the value of outputCol.
setOutputCols(value)  Sets the value of outputCols.
setParams(self, \*[, threshold, inputCol, ...])  Sets params for this Binarizer.
setThreshold(value)  Sets the value of threshold.
setThresholds(value)  Sets the value of thresholds.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
inputCol
inputCols
outputCol
outputCols
params  Returns all params ordered by name.
threshold
thresholds
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getInputCol()
Gets the value of inputCol or its default value.

getInputCols()
Gets the value of inputCols or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getOutputCols()
Gets the value of outputCols or its default value.

getParam(paramName)
Gets a param by its name.

getThreshold()
Gets the value of threshold or its default value.

getThresholds()
Gets the value of thresholds or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.
setInputCols(value)
Sets the value of inputCols.

New in version 3.0.0.

setOutputCol(value)
Sets the value of outputCol.

setOutputCols(value)
Sets the value of outputCols.

New in version 3.0.0.

setParams(self, \*, threshold=0.0, inputCol=None, outputCol=None, thresholds=None, inputCols=None, outputCols=None)
Sets params for this Binarizer.

New in version 1.4.0.

setThreshold(value)
Sets the value of threshold.

New in version 1.4.0.

setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
threshold = Param(parent='undefined', name='threshold', doc='Param for threshold used to binarize continuous features. The features greater than the threshold will be binarized to 1.0. The features equal to or less than the threshold will be binarized to 0.0')
thresholds = Param(parent='undefined', name='thresholds', doc='Param for array of threshold used to binarize continuous features. This is for multiple columns input. If transforming multiple columns and thresholds is not set, but threshold is set, then threshold will be applied across all columns.')
BucketedRandomProjectionLSH
class pyspark.ml.feature.BucketedRandomProjectionLSH(*args, **kwargs)
LSH class for Euclidean distance metrics. The input is dense or sparse vectors, each of which represents a point in the Euclidean distance space. The output will be vectors of configurable dimension. Hash values in the same dimension are calculated by the same hash function.
New in version 2.2.0.
Notes
• Stable Distributions in Wikipedia article on Locality-sensitive hashing
• Hashing for Similarity Search: A Survey
Examples
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.sql.functions import col
>>> data = [(0, Vectors.dense([-1.0, -1.0]),),
...         (1, Vectors.dense([-1.0, 1.0]),),
...         (2, Vectors.dense([1.0, -1.0]),),
...         (3, Vectors.dense([1.0, 1.0]),)]
>>> df = spark.createDataFrame(data, ["id", "features"])
>>> brp = BucketedRandomProjectionLSH()
>>> brp.setInputCol("features")
BucketedRandomProjectionLSH...
>>> brp.setOutputCol("hashes")
BucketedRandomProjectionLSH...
>>> brp.setSeed(12345)
BucketedRandomProjectionLSH...
>>> brp.setBucketLength(1.0)
BucketedRandomProjectionLSH...
>>> model = brp.fit(df)
>>> model.getBucketLength()
1.0
>>> model.setOutputCol("hashes")
BucketedRandomProjectionLSHModel...
>>> model.transform(df).head()
Row(id=0, features=DenseVector([-1.0, -1.0]), hashes=[DenseVector([-1.0])])
>>> data2 = [(4, Vectors.dense([2.0, 2.0]),),
...          (5, Vectors.dense([2.0, 3.0]),),
...          (6, Vectors.dense([3.0, 2.0]),),
...          (7, Vectors.dense([3.0, 3.0]),)]
>>> df2 = spark.createDataFrame(data2, ["id", "features"])
>>> model.approxNearestNeighbors(df2, Vectors.dense([1.0, 2.0]), 1).collect()
[Row(id=4, features=DenseVector([2.0, 2.0]), hashes=[DenseVector([1.0])], distCol=1.0)]
>>> model.approxSimilarityJoin(df, df2, 3.0, distCol="EuclideanDistance").select(
...     col("datasetA.id").alias("idA"),
...     col("datasetB.id").alias("idB"),
...     col("EuclideanDistance")).show()
+---+---+-----------------+
|idA|idB|EuclideanDistance|
+---+---+-----------------+
|  3|  6| 2.23606797749979|
+---+---+-----------------+
...
>>> model.approxSimilarityJoin(df, df2, 3, distCol="EuclideanDistance").select(
...     col("datasetA.id").alias("idA"),
...     col("datasetB.id").alias("idB"),
...     col("EuclideanDistance")).show()
+---+---+-----------------+
|idA|idB|EuclideanDistance|
+---+---+-----------------+
|  3|  6| 2.23606797749979|
+---+---+-----------------+
...
>>> brpPath = temp_path + "/brp"
>>> brp.save(brpPath)
>>> brp2 = BucketedRandomProjectionLSH.load(brpPath)
>>> brp2.getBucketLength() == brp.getBucketLength()
True
>>> modelPath = temp_path + "/brp-model"
>>> model.save(modelPath)
>>> model2 = BucketedRandomProjectionLSHModel.load(modelPath)
>>> model.transform(df).head().hashes == model2.transform(df).head().hashes
True
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getBucketLength()  Gets the value of bucketLength or its default value.
getInputCol()  Gets the value of inputCol or its default value.
getNumHashTables()  Gets the value of numHashTables or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getSeed()  Gets the value of seed or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setBucketLength(value)  Sets the value of bucketLength.
setInputCol(value)  Sets the value of inputCol.
setNumHashTables(value)  Sets the value of numHashTables.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, \*[, inputCol, outputCol, ...])  Sets params for this BucketedRandomProjectionLSH.
setSeed(value)  Sets the value of seed.
write()  Returns an MLWriter instance for this ML instance.
Attributes
bucketLength
inputCol
numHashTables
outputCol
params  Returns all params ordered by name.
seed
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getBucketLength()
Gets the value of bucketLength or its default value.

New in version 2.2.0.

getInputCol()
Gets the value of inputCol or its default value.

getNumHashTables()
Gets the value of numHashTables or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getSeed()
Gets the value of seed or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setBucketLength(value)
Sets the value of bucketLength.

New in version 2.2.0.

setInputCol(value)
Sets the value of inputCol.

setNumHashTables(value)
Sets the value of numHashTables.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, \*, inputCol=None, outputCol=None, seed=None, numHashTables=1, bucketLength=None)
Sets params for this BucketedRandomProjectionLSH.

New in version 2.2.0.

setSeed(value)
Sets the value of seed.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
bucketLength = Param(parent='undefined', name='bucketLength', doc='the length of each hash bucket, a larger bucket lowers the false negative rate.')
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
numHashTables = Param(parent='undefined', name='numHashTables', doc='number of hash tables, where increasing number of hash tables lowers the false negative rate, and decreasing it improves the running performance.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
seed = Param(parent='undefined', name='seed', doc='random seed.')
BucketedRandomProjectionLSHModel
class pyspark.ml.feature.BucketedRandomProjectionLSHModel(java_model=None)
Model fitted by BucketedRandomProjectionLSH, where multiple random vectors are stored. The vectors are normalized to be unit vectors and each vector is used in a hash function: h_i(x) = floor(r_i · x / bucketLength), where r_i is the i-th random unit vector. The number of buckets will be (max L2 norm of input vectors) / bucketLength.
New in version 2.2.0.
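A plain-Python sketch of the hash computation above; the random unit vector here is a hypothetical stand-in for the ones the fitted model actually stores:

import math

def brp_hash(x, random_unit_vectors, bucket_length):
    # One hash value per random unit vector: floor(r_i · x / bucketLength).
    return [math.floor(sum(r_j * x_j for r_j, x_j in zip(r, x)) / bucket_length)
            for r in random_unit_vectors]

r_vecs = [(0.6, 0.8)]  # hypothetical unit vector (norm 1)
print(brp_hash((1.0, 1.0), r_vecs, 1.0))  # [1], since floor(1.4 / 1.0) = 1
print(brp_hash((1.1, 1.0), r_vecs, 1.0))  # [1], nearby points often share a bucket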
Methods
approxNearestNeighbors(dataset, key, ...[, ...])  Given a large dataset and an item, approximately find at most k items which have the closest distance to the item.
approxSimilarityJoin(datasetA, datasetB, ...)  Join two datasets to approximately find all pairs of rows whose distances are smaller than the threshold.
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBucketLength()  Gets the value of bucketLength or its default value.
getInputCol()  Gets the value of inputCol or its default value.
getNumHashTables()  Gets the value of numHashTables or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setInputCol(value)  Sets the value of inputCol.
setOutputCol(value)  Sets the value of outputCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
bucketLength
inputCol
numHashTables
outputCol
params  Returns all params ordered by name.
Methods Documentation
approxNearestNeighbors(dataset, key, numNearestNeighbors, distCol='distCol')
Given a large dataset and an item, approximately find at most k items which have the closest distance to the item. If the outputCol is missing, the method will transform the data; if the outputCol exists, it will use that. This allows caching of the transformed data when necessary.
Parameters
dataset [pyspark.sql.DataFrame] The dataset to search for nearest neighbors of thekey.
key [pyspark.ml.linalg.Vector] Feature vector representing the item to search for.
numNearestNeighbors [int] The maximum number of nearest neighbors.
distCol [str] Output column for storing the distance between each result row and the key.Use “distCol” as default value if it’s not specified.
Returns
pyspark.sql.DataFrame A dataset containing at most k items closest to the key. Acolumn “distCol” is added to show the distance between each row and the key.
Notes
This method is experimental and will likely change behavior in the next release.
approxSimilarityJoin(datasetA, datasetB, threshold, distCol='distCol')
Join two datasets to approximately find all pairs of rows whose distances are smaller than the threshold. If the outputCol is missing, the method will transform the data; if the outputCol exists, it will use that. This allows caching of the transformed data when necessary.
Parameters
datasetA [pyspark.sql.DataFrame] One of the datasets to join.
datasetB [pyspark.sql.DataFrame] Another dataset to join.
threshold [float] The threshold for the distance of row pairs.
distCol [str, optional] Output column for storing the distance between each pair of rows.Use “distCol” as default value if it’s not specified.
Returns
pyspark.sql.DataFrame A joined dataset containing pairs of rows. The original rowsare in columns “datasetA” and “datasetB”, and a column “distCol” is added to show thedistance between each pair.
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getBucketLength()
Gets the value of bucketLength or its default value.
New in version 2.2.0.
getInputCol()
Gets the value of inputCol or its default value.

getNumHashTables()
Gets the value of numHashTables or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.

setOutputCol(value)
Sets the value of outputCol.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
bucketLength = Param(parent='undefined', name='bucketLength', doc='the length of each hash bucket, a larger bucket lowers the false negative rate.')
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
numHashTables = Param(parent='undefined', name='numHashTables', doc='number of hash tables, where increasing number of hash tables lowers the false negative rate, and decreasing it improves the running performance.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
Bucketizer
class pyspark.ml.feature.Bucketizer(*args, **kwargs)
Maps a column of continuous features to a column of feature buckets. Since 3.0.0, Bucketizer can map multiple columns at once by setting the inputCols parameter. Note that when both the inputCol and inputCols parameters are set, an Exception will be thrown. The splits parameter is only used for single column usage, and splitsArray is for multiple columns.
New in version 1.4.0.
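A plain-Python sketch of the single-column bucket lookup, mirroring the splits semantics documented under the splits attribute below (bucket i covers [splits[i], splits[i+1]), and the last bucket also includes its upper edge):

import bisect

def bucket_index(value, splits):
    # splits must be strictly increasing; n+1 splits define n buckets.
    if value == splits[-1]:
        return len(splits) - 2  # the last bucket also includes its upper edge
    return bisect.bisect_right(splits, value) - 1

splits = [float("-inf"), 0.5, 1.4, float("inf")]
print(bucket_index(0.4, splits))  # 0: falls in [-inf, 0.5)
print(bucket_index(0.5, splits))  # 1: left edges are inclusive
print(bucket_index(1.5, splits))  # 2: falls in [1.4, inf]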
Examples
>>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")),
...     (float("nan"), 1.0), (float("nan"), 0.0)]
>>> df = spark.createDataFrame(values, ["values1", "values2"])
>>> bucketizer = Bucketizer()
>>> bucketizer.setSplits([-float("inf"), 0.5, 1.4, float("inf")])
Bucketizer...
>>> bucketizer.setInputCol("values1")
Bucketizer...
>>> bucketizer.setOutputCol("buckets")
Bucketizer...
>>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect()
>>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1"))
>>> bucketed.show(truncate=False)
+-------+-------+
|values1|buckets|
+-------+-------+
|0.1    |0.0    |
|0.4    |0.0    |
|1.2    |1.0    |
|1.5    |2.0    |
|NaN    |3.0    |
|NaN    |3.0    |
+-------+-------+
...
>>> bucketizer.setParams(outputCol="b").transform(df).head().b
0.0
>>> bucketizerPath = temp_path + "/bucketizer"
>>> bucketizer.save(bucketizerPath)
>>> loadedBucketizer = Bucketizer.load(bucketizerPath)
>>> loadedBucketizer.getSplits() == bucketizer.getSplits()
True
>>> loadedBucketizer.transform(df).take(1) == bucketizer.transform(df).take(1)
True
>>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect()
>>> len(bucketed)
4
>>> bucketizer2 = Bucketizer(splitsArray=
...     [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]],
...     inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"])
>>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df)
>>> bucketed2.show(truncate=False)
+-------+-------+--------+--------+
|values1|values2|buckets1|buckets2|
+-------+-------+--------+--------+
|0.1    |0.0    |0.0     |0.0     |
|0.4    |1.0    |0.0     |1.0     |
|1.2    |1.3    |1.0     |1.0     |
|1.5    |NaN    |2.0     |2.0     |
|NaN    |1.0    |3.0     |1.0     |
|NaN    |0.0    |3.0     |0.0     |
+-------+-------+--------+--------+
...
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getHandleInvalid()  Gets the value of handleInvalid or its default value.
getInputCol()  Gets the value of inputCol or its default value.
getInputCols()  Gets the value of inputCols or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getOutputCols()  Gets the value of outputCols or its default value.
getParam(paramName)  Gets a param by its name.
getSplits()  Gets the value of splits or its default value.
getSplitsArray()  Gets the array of split points or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setHandleInvalid(value)  Sets the value of handleInvalid.
setInputCol(value)  Sets the value of inputCol.
setInputCols(value)  Sets the value of inputCols.
setOutputCol(value)  Sets the value of outputCol.
setOutputCols(value)  Sets the value of outputCols.
setParams(self, \*[, splits, inputCol, ...])  Sets params for this Bucketizer.
setSplits(value)  Sets the value of splits.
setSplitsArray(value)  Sets the value of splitsArray.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
handleInvalid
inputCol
inputCols
outputCol
outputCols
params  Returns all params ordered by name.
splits
splitsArray
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getHandleInvalid()
Gets the value of handleInvalid or its default value.

getInputCol()
Gets the value of inputCol or its default value.

getInputCols()
Gets the value of inputCols or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getOutputCols()
Gets the value of outputCols or its default value.

getParam(paramName)
Gets a param by its name.

getSplits()
Gets the value of splits or its default value.

New in version 1.4.0.

getSplitsArray()
Gets the array of split points or its default value.

New in version 3.0.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.
classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setHandleInvalid(value)
Sets the value of handleInvalid.

setInputCol(value)
Sets the value of inputCol.

setInputCols(value)
Sets the value of inputCols.

New in version 3.0.0.

setOutputCol(value)
Sets the value of outputCol.

setOutputCols(value)
Sets the value of outputCols.

New in version 3.0.0.

setParams(self, \*, splits=None, inputCol=None, outputCol=None, handleInvalid="error", splitsArray=None, inputCols=None, outputCols=None)
Sets params for this Bucketizer.

New in version 1.4.0.

setSplits(value)
Sets the value of splits.

New in version 1.4.0.

setSplitsArray(value)
Sets the value of splitsArray.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
handleInvalid = Param(parent='undefined', name='handleInvalid', doc="how to handle invalid entries containing NaN values. Values outside the splits will always be treated as errors. Options are 'skip' (filter out rows with invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special additional bucket). Note that in the multiple column case, the invalid handling is applied to all columns. That said for 'error' it will throw an error if any invalids are found in any column, for 'skip' it will skip rows with any invalids in any columns, etc.")
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
splits = Param(parent='undefined', name='splits', doc='Split points for mapping continuous features into buckets. With n+1 splits, there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which also includes y. The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be explicitly provided to cover all Double values; otherwise, values outside the splits specified will be treated as errors.')
splitsArray = Param(parent='undefined', name='splitsArray', doc='The array of split points for mapping continuous features into buckets for multiple columns. For each input column, with n+1 splits, there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which also includes y. The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be explicitly provided to cover all Double values; otherwise, values outside the splits specified will be treated as errors.')
ChiSqSelector
class pyspark.ml.feature.ChiSqSelector(*args, **kwargs)
Chi-Squared feature selection, which selects categorical features to use for predicting a categorical label. The selector supports different selection methods: numTopFeatures, percentile, fpr, fdr, fwe.
• numTopFeatures chooses a fixed number of top features according to a chi-squared test.
• percentile is similar but chooses a fraction of all features instead of a fixed number.
• fpr chooses all features whose p-values are below a threshold, thus controlling the false positive rate of selection.
• fdr uses the Benjamini-Hochberg procedure to choose all features whose false discovery rate is below a threshold.
• fwe chooses all features whose p-values are below a threshold. The threshold is scaled by 1/numFeatures, thus controlling the family-wise error rate of selection.
By default, the selection method is numTopFeatures, with the default number of top features set to 50.
New in version 2.0.0.
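As with ANOVASelector, the selection method can be switched at configuration time; a minimal, hypothetical sketch selecting by false positive rate (assuming a labeled DataFrame df like the one in the Examples below):

>>> selector = ChiSqSelector(outputCol="selectedFeatures")
>>> selector.setSelectorType("fpr").setFpr(0.01)
ChiSqSelector...
>>> model = selector.fit(df)  # keeps features whose chi-squared p-value is below 0.01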
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame(
...     [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
...      (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
...      (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)],
...     ["features", "label"])
>>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
>>> model = selector.fit(df)
>>> model.getFeaturesCol()
'features'
>>> model.setFeaturesCol("features")
ChiSqSelectorModel...
>>> model.transform(df).head().selectedFeatures
DenseVector([18.0])
>>> model.selectedFeatures
[2]
>>> chiSqSelectorPath = temp_path + "/chi-sq-selector"
>>> selector.save(chiSqSelectorPath)
>>> loadedSelector = ChiSqSelector.load(chiSqSelectorPath)
>>> loadedSelector.getNumTopFeatures() == selector.getNumTopFeatures()
True
>>> modelPath = temp_path + "/chi-sq-selector-model"
>>> model.save(modelPath)
>>> loadedModel = ChiSqSelectorModel.load(modelPath)
>>> loadedModel.selectedFeatures == model.selectedFeatures
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getFdr()  Gets the value of fdr or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFpr()  Gets the value of fpr or its default value.
getFwe()  Gets the value of fwe or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getNumTopFeatures()  Gets the value of numTopFeatures or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getPercentile()  Gets the value of percentile or its default value.
getSelectorType()  Gets the value of selectorType or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFdr(value)  Sets the value of fdr.
setFeaturesCol(value)  Sets the value of featuresCol.
setFpr(value)  Sets the value of fpr.
setFwe(value)  Sets the value of fwe.
setLabelCol(value)  Sets the value of labelCol.
setNumTopFeatures(value)  Sets the value of numTopFeatures.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, \*[, numTopFeatures, ...])  Sets params for this ChiSqSelector.
setPercentile(value)  Sets the value of percentile.
setSelectorType(value)  Sets the value of selectorType.
write()  Returns an MLWriter instance for this ML instance.
Attributes
fdr
featuresCol
fpr
fwe
labelCol
numTopFeatures
outputCol
params  Returns all params ordered by name.
percentile
selectorType
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
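As a small illustration of that precedence (a sketch using ChiSqSelector, though any Params instance behaves the same way):

from pyspark.ml.feature import ChiSqSelector

selector = ChiSqSelector()                        # numTopFeatures defaults to 50
selector.setNumTopFeatures(10)                    # user-supplied value beats the default
merged = selector.extractParamMap({selector.numTopFeatures: 5})
print(merged[selector.numTopFeatures])            # 5 -- the extra value wins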
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
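Because index values may arrive out of order, a typical consumption pattern (a minimal sketch, assuming an active spark session) slots each model back into its param map's position:

from pyspark.ml.feature import ChiSqSelector
from pyspark.ml.linalg import Vectors

df = spark.createDataFrame(
    [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
     (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0)],
    ["features", "label"])

selector = ChiSqSelector(outputCol="selectedFeatures")
paramMaps = [{selector.numTopFeatures: 1}, {selector.numTopFeatures: 2}]
models = [None] * len(paramMaps)
for index, model in selector.fitMultiple(df, paramMaps):
    models[index] = model   # indices may arrive out of order; store by index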
getFdr()
Gets the value of fdr or its default value.
New in version 2.2.0.
getFeaturesCol()
Gets the value of featuresCol or its default value.

getFpr()
Gets the value of fpr or its default value.
New in version 2.1.0.
getFwe()
Gets the value of fwe or its default value.
New in version 2.2.0.
getLabelCol()
Gets the value of labelCol or its default value.

getNumTopFeatures()
Gets the value of numTopFeatures or its default value.
New in version 2.0.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getPercentile()
Gets the value of percentile or its default value.
New in version 2.1.0.
getSelectorType()
Gets the value of selectorType or its default value.
New in version 2.1.0.
hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFdr(value)
Sets the value of fdr. Only applicable when selectorType = "fdr".
New in version 2.2.0.
setFeaturesCol(value)
Sets the value of featuresCol.

setFpr(value)
Sets the value of fpr. Only applicable when selectorType = "fpr".
New in version 2.1.0.
setFwe(value)
Sets the value of fwe. Only applicable when selectorType = "fwe".
New in version 2.2.0.
setLabelCol(value)
Sets the value of labelCol.

setNumTopFeatures(value)
Sets the value of numTopFeatures. Only applicable when selectorType = "numTopFeatures".
New in version 2.0.0.
setOutputCol(value)
Sets the value of outputCol.

setParams(self, *, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, fdr=0.05, fwe=0.05)
Sets params for this ChiSqSelector.
New in version 2.0.0.
setPercentile(value)
Sets the value of percentile. Only applicable when selectorType = "percentile".
New in version 2.1.0.
setSelectorType(value)
Sets the value of selectorType.
New in version 2.1.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
fdr = Param(parent='undefined', name='fdr', doc='The upper bound of the expected false discovery rate.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fpr = Param(parent='undefined', name='fpr', doc='The highest p-value for features to be kept.')
fwe = Param(parent='undefined', name='fwe', doc='The upper bound of the expected family-wise error rate.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
numTopFeatures = Param(parent='undefined', name='numTopFeatures', doc='Number of features that selector will select, ordered by ascending p-value. If the number of features is < numTopFeatures, then this will select all features.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
percentile = Param(parent='undefined', name='percentile', doc='Percentile of features that selector will select, ordered by ascending p-value.')
selectorType = Param(parent='undefined', name='selectorType', doc='The selector type. Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.')
ChiSqSelectorModel
class pyspark.ml.feature.ChiSqSelectorModel(java_model=None)
Model fitted by ChiSqSelector.
New in version 2.0.0.
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFdr()  Gets the value of fdr or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFpr()  Gets the value of fpr or its default value.
getFwe()  Gets the value of fwe or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getNumTopFeatures()  Gets the value of numTopFeatures or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getPercentile()  Gets the value of percentile or its default value.
getSelectorType()  Gets the value of selectorType or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setOutputCol(value)  Sets the value of outputCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
fdr
featuresCol
fpr
fwe
labelCol
numTopFeatures
outputCol
params  Returns all params ordered by name.
percentile
selectedFeatures  List of indices to select (filter).
selectorType
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getFdr()
Gets the value of fdr or its default value.
New in version 2.2.0.
getFeaturesCol()
Gets the value of featuresCol or its default value.

getFpr()
Gets the value of fpr or its default value.
New in version 2.1.0.
getFwe()
Gets the value of fwe or its default value.
New in version 2.2.0.
getLabelCol()
Gets the value of labelCol or its default value.

getNumTopFeatures()
Gets the value of numTopFeatures or its default value.
New in version 2.0.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.
getPercentile()
Gets the value of percentile or its default value.
New in version 2.1.0.
getSelectorType()
Gets the value of selectorType or its default value.
New in version 2.1.0.
hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.
New in version 3.0.0.
setOutputCol(value)
Sets the value of outputCol.
New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
fdr = Param(parent='undefined', name='fdr', doc='The upper bound of the expected false discovery rate.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fpr = Param(parent='undefined', name='fpr', doc='The highest p-value for features to be kept.')
fwe = Param(parent='undefined', name='fwe', doc='The upper bound of the expected family-wise error rate.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
numTopFeatures = Param(parent='undefined', name='numTopFeatures', doc='Number of features that selector will select, ordered by ascending p-value. If the number of features is < numTopFeatures, then this will select all features.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
percentile = Param(parent='undefined', name='percentile', doc='Percentile of features that selector will select, ordered by ascending p-value.')
selectedFeatures
List of indices to select (filter).
New in version 2.0.0.
selectorType = Param(parent='undefined', name='selectorType', doc='The selector type. Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.')
CountVectorizer
class pyspark.ml.feature.CountVectorizer(*args, **kwargs)
Extracts a vocabulary from document collections and generates a CountVectorizerModel.
New in version 1.6.0.
Examples
>>> df = spark.createDataFrame(... [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])],... ["label", "raw"])>>> cv = CountVectorizer()>>> cv.setInputCol("raw")CountVectorizer...>>> cv.setOutputCol("vectors")CountVectorizer...>>> model = cv.fit(df)>>> model.setInputCol("raw")CountVectorizerModel...>>> model.transform(df).show(truncate=False)+-----+---------------+-------------------------+|label|raw |vectors |+-----+---------------+-------------------------+|0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])||1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|+-----+---------------+-------------------------+...>>> sorted(model.vocabulary) == ['a', 'b', 'c']True>>> countVectorizerPath = temp_path + "/count-vectorizer">>> cv.save(countVectorizerPath)>>> loadedCv = CountVectorizer.load(countVectorizerPath)>>> loadedCv.getMinDF() == cv.getMinDF()True>>> loadedCv.getMinTF() == cv.getMinTF()True>>> loadedCv.getVocabSize() == cv.getVocabSize()True>>> modelPath = temp_path + "/count-vectorizer-model">>> model.save(modelPath)>>> loadedModel = CountVectorizerModel.load(modelPath)>>> loadedModel.vocabulary == model.vocabularyTrue>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)True>>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"],... inputCol="raw", outputCol="vectors")>>> fromVocabModel.transform(df).show(truncate=False)+-----+---------------+-------------------------+|label|raw |vectors |+-----+---------------+-------------------------+|0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])||1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|+-----+---------------+-------------------------+...
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getBinary()  Gets the value of binary or its default value.
getInputCol()  Gets the value of inputCol or its default value.
getMaxDF()  Gets the value of maxDF or its default value.
getMinDF()  Gets the value of minDF or its default value.
getMinTF()  Gets the value of minTF or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getVocabSize()  Gets the value of vocabSize or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setBinary(value)  Sets the value of binary.
setInputCol(value)  Sets the value of inputCol.
setMaxDF(value)  Sets the value of maxDF.
setMinDF(value)  Sets the value of minDF.
setMinTF(value)  Sets the value of minTF.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, *[, minTF, minDF, maxDF, ...])  Sets the params for the CountVectorizer.
setVocabSize(value)  Sets the value of vocabSize.
write()  Returns an MLWriter instance for this ML instance.
Attributes
binary
inputCol
maxDF
minDF
minTF
outputCol
params  Returns all params ordered by name.
vocabSize
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getBinary()
Gets the value of binary or its default value.
New in version 2.0.0.
getInputCol()
Gets the value of inputCol or its default value.

getMaxDF()
Gets the value of maxDF or its default value.
New in version 2.4.0.
getMinDF()
Gets the value of minDF or its default value.
New in version 1.6.0.
getMinTF()
Gets the value of minTF or its default value.
New in version 1.6.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getVocabSize()
Gets the value of vocabSize or its default value.
New in version 1.6.0.
hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.
setBinary(value)
Sets the value of binary.
New in version 2.0.0.
setInputCol(value)
Sets the value of inputCol.

setMaxDF(value)
Sets the value of maxDF.
New in version 2.4.0.
setMinDF(value)
Sets the value of minDF.
New in version 1.6.0.
setMinTF(value)
Sets the value of minTF.
New in version 1.6.0.
setOutputCol(value)
Sets the value of outputCol.

setParams(self, *, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, inputCol=None, outputCol=None)
Sets the params for the CountVectorizer.
New in version 1.6.0.
setVocabSize(value)
Sets the value of vocabSize.
New in version 1.6.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
binary = Param(parent='undefined', name='binary', doc='Binary toggle to control the output vector values. If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for discrete probabilistic models that model binary events rather than integer counts. Default False')
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
maxDF = Param(parent='undefined', name='maxDF', doc='Specifies the maximum number of different documents a term could appear in to be included in the vocabulary. A term that appears more than the threshold will be ignored. If this is an integer >= 1, this specifies the maximum number of documents the term could appear in; if this is a double in [0,1), then this specifies the maximum fraction of documents the term could appear in. Default (2^63) - 1')
minDF = Param(parent='undefined', name='minDF', doc='Specifies the minimum number of different documents a term must appear in to be included in the vocabulary. If this is an integer >= 1, this specifies the number of documents the term must appear in; if this is a double in [0,1), then this specifies the fraction of documents. Default 1.0')
minTF = Param(parent='undefined', name='minTF', doc="Filter to ignore rare words in a document. For each document, terms with frequency/count less than the given threshold are ignored. If this is an integer >= 1, then this specifies a count (of times the term must appear in the document); if this is a double in [0,1), then this specifies a fraction (out of the document's token count). Note that the parameter is only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0")
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
vocabSize = Param(parent='undefined', name='vocabSize', doc='max size of the vocabulary. Default 1 << 18.')
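Note that minDF and maxDF accept either an absolute document count (values >= 1) or a fraction of the corpus (values in [0, 1)). A brief sketch of both forms (column names are illustrative):

from pyspark.ml.feature import CountVectorizer

cv_counts = CountVectorizer(inputCol="raw", outputCol="vectors",
                            minDF=2.0)   # term must appear in at least 2 documents
cv_fracs = CountVectorizer(inputCol="raw", outputCol="vectors",
                           minDF=0.5,    # ... or in at least half of the documents
                           maxDF=0.9)    # and drop near-ubiquitous terms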
CountVectorizerModel
class pyspark.ml.feature.CountVectorizerModel(java_model=None)
Model fitted by CountVectorizer.
New in version 1.6.0.
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
from_vocabulary(vocabulary, inputCol[, ...])  Construct the model directly from a vocabulary list of strings; requires an active SparkContext.
getBinary()  Gets the value of binary or its default value.
getInputCol()  Gets the value of inputCol or its default value.
getMaxDF()  Gets the value of maxDF or its default value.
getMinDF()  Gets the value of minDF or its default value.
getMinTF()  Gets the value of minTF or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getVocabSize()  Gets the value of vocabSize or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setBinary(value)  Sets the value of binary.
setInputCol(value)  Sets the value of inputCol.
setMinTF(value)  Sets the value of minTF.
setOutputCol(value)  Sets the value of outputCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
binary
inputCol
maxDF
minDF
minTF
outputCol
params  Returns all params ordered by name.
vocabSize
vocabulary  An array of terms in the vocabulary.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
classmethod from_vocabulary(vocabulary, inputCol, outputCol=None, minTF=None, binary=None)
Construct the model directly from a vocabulary list of strings; requires an active SparkContext.
New in version 2.4.0.
getBinary()
Gets the value of binary or its default value.
New in version 2.0.0.
getInputCol()
Gets the value of inputCol or its default value.

getMaxDF()
Gets the value of maxDF or its default value.
New in version 2.4.0.
getMinDF()
Gets the value of minDF or its default value.
New in version 1.6.0.
getMinTF()
Gets the value of minTF or its default value.
New in version 1.6.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getVocabSize()
Gets the value of vocabSize or its default value.
New in version 1.6.0.
hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setBinary(value)
Sets the value of binary.
New in version 2.4.0.
setInputCol(value)
Sets the value of inputCol.
New in version 3.0.0.
setMinTF(value)
Sets the value of minTF.
New in version 2.4.0.
setOutputCol(value)
Sets the value of outputCol.
New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
binary = Param(parent='undefined', name='binary', doc='Binary toggle to control the output vector values. If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for discrete probabilistic models that model binary events rather than integer counts. Default False')
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
maxDF = Param(parent='undefined', name='maxDF', doc='Specifies the maximum number of different documents a term could appear in to be included in the vocabulary. A term that appears more than the threshold will be ignored. If this is an integer >= 1, this specifies the maximum number of documents the term could appear in; if this is a double in [0,1), then this specifies the maximum fraction of documents the term could appear in. Default (2^63) - 1')
minDF = Param(parent='undefined', name='minDF', doc='Specifies the minimum number of different documents a term must appear in to be included in the vocabulary. If this is an integer >= 1, this specifies the number of documents the term must appear in; if this is a double in [0,1), then this specifies the fraction of documents. Default 1.0')
minTF = Param(parent='undefined', name='minTF', doc="Filter to ignore rare words in a document. For each document, terms with frequency/count less than the given threshold are ignored. If this is an integer >= 1, then this specifies a count (of times the term must appear in the document); if this is a double in [0,1), then this specifies a fraction (out of the document's token count). Note that the parameter is only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0")
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
vocabSize = Param(parent='undefined', name='vocabSize', doc='max size of the vocabulary. Default 1 << 18.')
vocabulary
An array of terms in the vocabulary.
New in version 1.6.0.
DCT
class pyspark.ml.feature.DCT(*args, **kwargs)
A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero padding is performed on the input vector. It returns a real vector of the same length representing the DCT. The return vector is scaled such that the transform matrix is unitary (aka scaled DCT-II).
New in version 1.6.0.
Notes
More information on Wikipedia.
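The unitary scaling means the result is an orthonormal DCT-II, so the first doctest value below can be cross-checked against SciPy (a quick sketch, assuming SciPy is available; it is not a PySpark dependency):

from scipy.fft import dct

# An orthonormal (norm='ortho') DCT-II of the same vector used in the doctest.
print(dct([5.0, 8.0, 6.0], type=2, norm='ortho'))
# [10.969..., -0.707..., -2.041...] -- matching resultVec below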
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df1 = spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"])
>>> dct = DCT()
>>> dct.setInverse(False)
DCT...
>>> dct.setInputCol("vec")
DCT...
>>> dct.setOutputCol("resultVec")
DCT...
>>> df2 = dct.transform(df1)
>>> df2.head().resultVec
DenseVector([10.969..., -0.707..., -2.041...])
>>> df3 = DCT(inverse=True, inputCol="resultVec", outputCol="origVec").transform(df2)
>>> df3.head().origVec
DenseVector([5.0, 8.0, 6.0])
>>> dctPath = temp_path + "/dct"
>>> dct.save(dctPath)
>>> loadedDtc = DCT.load(dctPath)
>>> loadedDtc.transform(df1).take(1) == dct.transform(df1).take(1)
True
>>> loadedDtc.getInverse()
False
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol()  Gets the value of inputCol or its default value.
getInverse()  Gets the value of inverse or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setInputCol(value)  Sets the value of inputCol.
setInverse(value)  Sets the value of inverse.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, *[, inverse, inputCol, ...])  Sets params for this DCT.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
inputCol
inverse
outputCol
params  Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getInputCol()
Gets the value of inputCol or its default value.

getInverse()
Gets the value of inverse or its default value.
New in version 1.6.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.
getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.

setInverse(value)
Sets the value of inverse.
New in version 1.6.0.
setOutputCol(value)
Sets the value of outputCol.

setParams(self, *, inverse=False, inputCol=None, outputCol=None)
Sets params for this DCT.
New in version 1.6.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
inverse = Param(parent='undefined', name='inverse', doc='Set transformer to perform inverse DCT, default False.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
ElementwiseProduct
class pyspark.ml.feature.ElementwiseProduct(*args, **kwargs)
Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a provided “weight” vector. In other words, it scales each column of the dataset by a scalar multiplier.
New in version 1.5.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], ["values"])
>>> ep = ElementwiseProduct()
>>> ep.setScalingVec(Vectors.dense([1.0, 2.0, 3.0]))
ElementwiseProduct...
>>> ep.setInputCol("values")
ElementwiseProduct...
>>> ep.setOutputCol("eprod")
ElementwiseProduct...
>>> ep.transform(df).head().eprod
DenseVector([2.0, 2.0, 9.0])
>>> ep.setParams(scalingVec=Vectors.dense([2.0, 3.0, 5.0])).transform(df).head().eprod
DenseVector([4.0, 3.0, 15.0])
>>> elementwiseProductPath = temp_path + "/elementwise-product"
>>> ep.save(elementwiseProductPath)
>>> loadedEp = ElementwiseProduct.load(elementwiseProductPath)
>>> loadedEp.getScalingVec() == ep.getScalingVec()
True
>>> loadedEp.transform(df).take(1) == ep.transform(df).take(1)
True
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol()  Gets the value of inputCol or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getScalingVec()  Gets the value of scalingVec or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setInputCol(value)  Sets the value of inputCol.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, *[, scalingVec, inputCol, ...])  Sets params for this ElementwiseProduct.
setScalingVec(value)  Sets the value of scalingVec.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
inputCol
outputCol
params  Returns all params ordered by name.
scalingVec
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getInputCol()
Gets the value of inputCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getScalingVec()
Gets the value of scalingVec or its default value.
New in version 2.0.0.
hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.
isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, *, scalingVec=None, inputCol=None, outputCol=None)
Sets params for this ElementwiseProduct.
New in version 1.5.0.
setScalingVec(value)
Sets the value of scalingVec.
New in version 2.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
scalingVec = Param(parent='undefined', name='scalingVec', doc='Vector for hadamard product.')
FeatureHasher
class pyspark.ml.feature.FeatureHasher(*args, **kwargs)
Feature hashing projects a set of categorical or numerical features into a feature vector of specified dimension (typically substantially smaller than that of the original feature space). This is done using the hashing trick (https://en.wikipedia.org/wiki/Feature_hashing) to map features to indices in the feature vector.
The FeatureHasher transformer operates on multiple columns. Each column may contain either numeric or categorical features. Behavior and handling of column data types is as follows:

• Numeric columns: For numeric features, the hash value of the column name is used to map the feature value to its index in the feature vector. By default, numeric features are not treated as categorical (even when they are integers). To treat them as categorical, specify the relevant columns in categoricalCols.

• String columns: For categorical features, the hash value of the string “column_name=value” is used to map to the vector index, with an indicator value of 1.0. Thus, categorical features are “one-hot” encoded (similarly to using OneHotEncoder with dropLast=false).

• Boolean columns: Boolean values are treated in the same way as string columns. That is, boolean features are represented as “column_name=true” or “column_name=false”, with an indicator value of 1.0.
Null (missing) values are ignored (implicitly zero in the resulting feature vector).
Since a simple modulo is used to transform the hash function to a vector index, it is advisable to use a power of two as the numFeatures parameter; otherwise the features will not be mapped evenly to the vector indices.
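To make the mapping concrete, here is an illustrative sketch of the hashing trick (Spark uses MurmurHash3 internally, not Python's hash(), so these indices will not match FeatureHasher's actual output):

num_features = 2 ** 18          # a power of two, as recommended above

def hashed_index(feature_key):
    # Hash the feature key, then fold it into the vector via modulo.
    return hash(feature_key) % num_features

row = {"real": 2.0, "bool": True, "string": "foo"}
vector = {}
vector[hashed_index("real")] = row["real"]        # numeric: value is kept
vector[hashed_index("string=foo")] = 1.0          # categorical: one-hot indicator
vector[hashed_index("bool=true")] = 1.0           # boolean: treated like a string
print(vector)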
New in version 2.3.0.
Examples
>>> data = [(2.0, True, "1", "foo"), (3.0, False, "2", "bar")]
>>> cols = ["real", "bool", "stringNum", "string"]
>>> df = spark.createDataFrame(data, cols)
>>> hasher = FeatureHasher()
>>> hasher.setInputCols(cols)
FeatureHasher...
>>> hasher.setOutputCol("features")
FeatureHasher...
>>> hasher.transform(df).head().features
SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
>>> hasher.setCategoricalCols(["real"]).transform(df).head().features
SparseVector(262144, {171257: 1.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
>>> hasherPath = temp_path + "/hasher"
>>> hasher.save(hasherPath)
>>> loadedHasher = FeatureHasher.load(hasherPath)
>>> loadedHasher.getNumFeatures() == hasher.getNumFeatures()
True
>>> loadedHasher.transform(df).head().features == hasher.transform(df).head().features
True
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCategoricalCols()  Gets the value of categoricalCols or its default value.
getInputCols()  Gets the value of inputCols or its default value.
getNumFeatures()  Gets the value of numFeatures or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setCategoricalCols(value)  Sets the value of categoricalCols.
setInputCols(value)  Sets the value of inputCols.
setNumFeatures(value)  Sets the value of numFeatures.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, *[, numFeatures, ...])  Sets params for this FeatureHasher.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
categoricalCols
inputCols
numFeatures
outputCol
params  Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getCategoricalCols()
Gets the value of categoricalCols or its default value.
New in version 2.3.0.
getInputCols()
Gets the value of inputCols or its default value.

getNumFeatures()
Gets the value of numFeatures or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setCategoricalCols(value)
Sets the value of categoricalCols.
New in version 2.3.0.
setInputCols(value)
Sets the value of inputCols.

setNumFeatures(value)
Sets the value of numFeatures.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, *, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None)
Sets params for this FeatureHasher.
New in version 2.3.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
categoricalCols = Param(parent='undefined', name='categoricalCols', doc='numeric columns to treat as categorical')
inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')
numFeatures = Param(parent='undefined', name='numFeatures', doc='Number of features. Should be greater than 0.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
FValueSelector
class pyspark.ml.feature.FValueSelector(*args, **kwargs)
F Value Regression feature selector, which selects continuous features to use for predicting a continuous label. The selector supports different selection methods: numTopFeatures, percentile, fpr, fdr, fwe.

• numTopFeatures chooses a fixed number of top features according to an F value regression test.
• percentile is similar but chooses a fraction of all features instead of a fixed number.
• fpr chooses all features whose p-values are below a threshold, thus controlling the false positive rate of selection.

• fdr uses the Benjamini-Hochberg procedure to choose all features whose false discovery rate is below a threshold.

• fwe chooses all features whose p-values are below a threshold. The threshold is scaled by 1/numFeatures, thus controlling the family-wise error rate of selection.
By default, the selection method is numTopFeatures, with the default number of top features set to 50.
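For example, a minimal sketch (parameters chosen for illustration) of selecting by percentile instead of a fixed count:

from pyspark.ml.feature import FValueSelector

# Keep the top 20% of features ranked by the F value regression test,
# rather than a fixed number of them.
selector = FValueSelector(featuresCol="features", labelCol="label",
                          outputCol="selectedFeatures")
selector.setSelectorType("percentile").setPercentile(0.2)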
New in version 3.1.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame(
...     [(Vectors.dense([6.0, 7.0, 0.0, 7.0, 6.0, 0.0]), 4.6),
...      (Vectors.dense([0.0, 9.0, 6.0, 0.0, 5.0, 9.0]), 6.6),
...      (Vectors.dense([0.0, 9.0, 3.0, 0.0, 5.0, 5.0]), 5.1),
...      (Vectors.dense([0.0, 9.0, 8.0, 5.0, 6.0, 4.0]), 7.6),
...      (Vectors.dense([8.0, 9.0, 6.0, 5.0, 4.0, 4.0]), 9.0),
...      (Vectors.dense([8.0, 9.0, 6.0, 4.0, 0.0, 0.0]), 9.0)],
...     ["features", "label"])
>>> selector = FValueSelector(numTopFeatures=1, outputCol="selectedFeatures")
>>> model = selector.fit(df)
>>> model.getFeaturesCol()
'features'
>>> model.setFeaturesCol("features")
FValueSelectorModel...
>>> model.transform(df).head().selectedFeatures
DenseVector([0.0])
>>> model.selectedFeatures
[2]
>>> fvalueSelectorPath = temp_path + "/fvalue-selector"
>>> selector.save(fvalueSelectorPath)
>>> loadedSelector = FValueSelector.load(fvalueSelectorPath)
>>> loadedSelector.getNumTopFeatures() == selector.getNumTopFeatures()
True
>>> modelPath = temp_path + "/fvalue-selector-model"
>>> model.save(modelPath)
>>> loadedModel = FValueSelectorModel.load(modelPath)
>>> loadedModel.selectedFeatures == model.selectedFeatures
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getFdr() Gets the value of fdr or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFpr() Gets the value of fpr or its default value.
getFwe() Gets the value of fwe or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getNumTopFeatures() Gets the value of numTopFeatures or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
getPercentile() Gets the value of percentile or its default value.
getSelectorType() Gets the value of selectorType or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setFdr(value) Sets the value of fdr.
setFeaturesCol(value) Sets the value of featuresCol.
setFpr(value) Sets the value of fpr.
setFwe(value) Sets the value of fwe.
setLabelCol(value) Sets the value of labelCol.
setNumTopFeatures(value) Sets the value of numTopFeatures.
setOutputCol(value) Sets the value of outputCol.
setParams(self, *[, numTopFeatures, ...]) Sets params for this FValueSelector.
setPercentile(value) Sets the value of percentile.
setSelectorType(value) Sets the value of selectorType.
write() Returns an MLWriter instance for this ML instance.
Attributes
fdr
featuresCol
fpr
fwe
labelCol
numTopFeatures
outputCol
params Returns all params ordered by name.
percentile
selectorType
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
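A small sketch of that merge ordering (illustration only, using the FValueSelector documented here):

from pyspark.ml.feature import FValueSelector

selector = FValueSelector()                    # numTopFeatures defaults to 50
selector.setNumTopFeatures(10)                 # user-supplied value beats the default
merged = selector.extractParamMap({selector.numTopFeatures: 5})
print(merged[selector.numTopFeatures])         # 5: the extra map beats both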
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
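A sketch of collecting the fitted models back in input order (illustration only; df is assumed to have "features" and "label" columns):

from pyspark.ml.feature import FValueSelector

# Sketch only: `df` is assumed to exist with "features" and "label" columns.
selector = FValueSelector(featuresCol="features", labelCol="label",
                          outputCol="selected")
maps = [{selector.numTopFeatures: k} for k in (1, 2, 3)]
models = [None] * len(maps)
for index, model in selector.fitMultiple(df, maps):
    models[index] = model  # indices may arrive out of order, so place by index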
getFdr()
Gets the value of fdr or its default value.

New in version 2.2.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.
getFpr()
Gets the value of fpr or its default value.

New in version 2.1.0.

getFwe()
Gets the value of fwe or its default value.

New in version 2.2.0.

getLabelCol()
Gets the value of labelCol or its default value.

getNumTopFeatures()
Gets the value of numTopFeatures or its default value.

New in version 2.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getPercentile()
Gets the value of percentile or its default value.

New in version 2.1.0.

getSelectorType()
Gets the value of selectorType or its default value.

New in version 2.1.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setFdr(value)
Sets the value of fdr. Only applicable when selectorType = "fdr".
New in version 2.2.0.
setFeaturesCol(value)
Sets the value of featuresCol.

setFpr(value)
Sets the value of fpr. Only applicable when selectorType = "fpr".

New in version 2.1.0.

setFwe(value)
Sets the value of fwe. Only applicable when selectorType = "fwe".

New in version 2.2.0.

setLabelCol(value)
Sets the value of labelCol.

setNumTopFeatures(value)
Sets the value of numTopFeatures. Only applicable when selectorType = "numTopFeatures".

New in version 2.0.0.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, *, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, fdr=0.05, fwe=0.05)
Sets params for this FValueSelector.

New in version 3.1.0.

setPercentile(value)
Sets the value of percentile. Only applicable when selectorType = "percentile".

New in version 2.1.0.

setSelectorType(value)
Sets the value of selectorType.

New in version 2.1.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
fdr = Param(parent='undefined', name='fdr', doc='The upper bound of the expected false discovery rate.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fpr = Param(parent='undefined', name='fpr', doc='The highest p-value for features to be kept.')
fwe = Param(parent='undefined', name='fwe', doc='The upper bound of the expected family-wise error rate.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
numTopFeatures = Param(parent='undefined', name='numTopFeatures', doc='Number of features that selector will select, ordered by ascending p-value. If the number of features is < numTopFeatures, then this will select all features.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
percentile = Param(parent='undefined', name='percentile', doc='Percentile of features that selector will select, ordered by ascending p-value.')
selectorType = Param(parent='undefined', name='selectorType', doc='The selector type. Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.')
FValueSelectorModel
class pyspark.ml.feature.FValueSelectorModel(java_model=None)
Model fitted by FValueSelector.
New in version 3.1.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFdr() Gets the value of fdr or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFpr() Gets the value of fpr or its default value.
getFwe() Gets the value of fwe or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getNumTopFeatures() Gets the value of numTopFeatures or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
getPercentile() Gets the value of percentile or its default value.
getSelectorType() Gets the value of selectorType or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setOutputCol(value) Sets the value of outputCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
fdr
featuresCol
fpr
fwe
labelCol
numTopFeatures
outputCol
params Returns all params ordered by name.
percentile
selectedFeatures List of indices to select (filter).
selectorType
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getFdr()
Gets the value of fdr or its default value.

New in version 2.2.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.
getFpr()
Gets the value of fpr or its default value.

New in version 2.1.0.

getFwe()
Gets the value of fwe or its default value.

New in version 2.2.0.

getLabelCol()
Gets the value of labelCol or its default value.

getNumTopFeatures()
Gets the value of numTopFeatures or its default value.

New in version 2.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getPercentile()
Gets the value of percentile or its default value.

New in version 2.1.0.

getSelectorType()
Gets the value of selectorType or its default value.

New in version 2.1.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.
New in version 3.0.0.
setOutputCol(value)
Sets the value of outputCol.
New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
fdr = Param(parent='undefined', name='fdr', doc='The upper bound of the expected false discovery rate.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fpr = Param(parent='undefined', name='fpr', doc='The highest p-value for features to be kept.')
fwe = Param(parent='undefined', name='fwe', doc='The upper bound of the expected family-wise error rate.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
numTopFeatures = Param(parent='undefined', name='numTopFeatures', doc='Number of features that selector will select, ordered by ascending p-value. If the number of features is < numTopFeatures, then this will select all features.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
percentile = Param(parent='undefined', name='percentile', doc='Percentile of features that selector will select, ordered by ascending p-value.')
selectedFeatures
List of indices to select (filter).
New in version 2.0.0.
selectorType = Param(parent='undefined', name='selectorType', doc='The selector type. Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.')
HashingTF
class pyspark.ml.feature.HashingTF(*args, **kwargs)
Maps a sequence of terms to their term frequencies using the hashing trick. Currently we use Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32) to calculate the hash code value for the term object. Since a simple modulo is used to transform the hash function to a column index, it is advisable to use a power of two as the numFeatures parameter; otherwise the features will not be mapped evenly to the columns.
New in version 1.3.0.
Examples
>>> df = spark.createDataFrame([(["a", "b", "c"],)], ["words"])
>>> hashingTF = HashingTF(inputCol="words", outputCol="features")
>>> hashingTF.setNumFeatures(10)
HashingTF...
>>> hashingTF.transform(df).head().features
SparseVector(10, {5: 1.0, 7: 1.0, 8: 1.0})
>>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs
SparseVector(10, {5: 1.0, 7: 1.0, 8: 1.0})
>>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
>>> hashingTF.transform(df, params).head().vector
SparseVector(5, {0: 1.0, 2: 1.0, 3: 1.0})
>>> hashingTFPath = temp_path + "/hashing-tf"
>>> hashingTF.save(hashingTFPath)
>>> loadedHashingTF = HashingTF.load(hashingTFPath)
>>> loadedHashingTF.getNumFeatures() == hashingTF.getNumFeatures()
True
>>> loadedHashingTF.transform(df).take(1) == hashingTF.transform(df).take(1)
True
>>> hashingTF.indexOf("b")
5
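To echo the advice above about the modulo mapping, a brief sketch (not from the original reference; it assumes an active SparkSession named spark) that picks a power of two for numFeatures and clamps counts with binary:

from pyspark.ml.feature import HashingTF

# Sketch only: assumes an active SparkSession named `spark`.
df = spark.createDataFrame([(["a", "b", "a"],)], ["words"])

# A power of two keeps the modulo mapping even across columns, and
# binary=True clamps all non-zero term counts to 1.0.
tf = HashingTF(inputCol="words", outputCol="features",
               numFeatures=1024, binary=True)
print(tf.transform(df).head().features)  # each present term maps to 1.0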
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBinary() Gets the value of binary or its default value.
getInputCol() Gets the value of inputCol or its default value.
getNumFeatures() Gets the value of numFeatures or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
indexOf(term) Returns the index of the input term.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setBinary(value) Sets the value of binary.
setInputCol(value) Sets the value of inputCol.
setNumFeatures(value) Sets the value of numFeatures.
setOutputCol(value) Sets the value of outputCol.
setParams(self, *[, numFeatures, binary, ...]) Sets params for this HashingTF.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
binary
inputCol
numFeatures
outputCol
params Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getBinary()
Gets the value of binary or its default value.

New in version 2.0.0.

getInputCol()
Gets the value of inputCol or its default value.

getNumFeatures()
Gets the value of numFeatures or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

indexOf(term)
Returns the index of the input term.

New in version 3.0.0.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.
save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setBinary(value)
Sets the value of binary.

New in version 2.0.0.

setInputCol(value)
Sets the value of inputCol.

setNumFeatures(value)
Sets the value of numFeatures.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, *, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None)
Sets params for this HashingTF.
New in version 1.3.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
binary = Param(parent='undefined', name='binary', doc='If True, all non zero counts are set to 1. This is useful for discrete probabilistic models that model binary events rather than integer counts. Default False.')
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
numFeatures = Param(parent='undefined', name='numFeatures', doc='Number of features. Should be greater than 0.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
IDF
class pyspark.ml.feature.IDF(*args, **kwargs)
Compute the Inverse Document Frequency (IDF) given a collection of documents.
New in version 1.4.0.
Examples
>>> from pyspark.ml.linalg import DenseVector
>>> df = spark.createDataFrame([(DenseVector([1.0, 2.0]),),
...     (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)], ["tf"])
>>> idf = IDF(minDocFreq=3)
>>> idf.setInputCol("tf")
IDF...
>>> idf.setOutputCol("idf")
IDF...
>>> model = idf.fit(df)
>>> model.setOutputCol("idf")
IDFModel...
>>> model.getMinDocFreq()
3
>>> model.idf
DenseVector([0.0, 0.0])
>>> model.docFreq
[0, 3]
>>> model.numDocs == df.count()
True
>>> model.transform(df).head().idf
DenseVector([0.0, 0.0])
>>> idf.setParams(outputCol="freqs").fit(df).transform(df).collect()[1].freqs
DenseVector([0.0, 0.0])
>>> params = {idf.minDocFreq: 1, idf.outputCol: "vector"}
>>> idf.fit(df, params).transform(df).head().vector
DenseVector([0.2877, 0.0])
>>> idfPath = temp_path + "/idf"
>>> idf.save(idfPath)
>>> loadedIdf = IDF.load(idfPath)
>>> loadedIdf.getMinDocFreq() == idf.getMinDocFreq()
True
>>> modelPath = temp_path + "/idf-model"
>>> model.save(modelPath)
>>> loadedModel = IDFModel.load(modelPath)
>>> loadedModel.transform(df).head().idf == model.transform(df).head().idf
True
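IDF is typically chained after HashingTF in a TF-IDF pipeline; a hedged sketch of that chain (not from the original reference; it assumes an active SparkSession named spark):

from pyspark.ml import Pipeline
from pyspark.ml.feature import HashingTF, IDF, Tokenizer

# Sketch only: tokenize, hash term frequencies, then rescale them by
# inverse document frequency. Assumes an active SparkSession `spark`.
docs = spark.createDataFrame([("spark is fast",), ("spark is simple",)], ["text"])
pipeline = Pipeline(stages=[
    Tokenizer(inputCol="text", outputCol="words"),
    HashingTF(inputCol="words", outputCol="tf", numFeatures=1024),
    IDF(inputCol="tf", outputCol="tfidf", minDocFreq=1),
])
pipeline.fit(docs).transform(docs).select("tfidf").show(truncate=False)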
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getInputCol() Gets the value of inputCol or its default value.
getMinDocFreq() Gets the value of minDocFreq or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setMinDocFreq(value) Sets the value of minDocFreq.
setOutputCol(value) Sets the value of outputCol.
setParams(self, *[, minDocFreq, inputCol, ...]) Sets params for this IDF.
write() Returns an MLWriter instance for this ML instance.
Attributes
inputCol
minDocFreq
outputCol
params Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getInputCol()
Gets the value of inputCol or its default value.

getMinDocFreq()
Gets the value of minDocFreq or its default value.
New in version 1.4.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.

setMinDocFreq(value)
Sets the value of minDocFreq.
New in version 1.4.0.
setOutputCol(value)
Sets the value of outputCol.

setParams(self, *, minDocFreq=0, inputCol=None, outputCol=None)
Sets params for this IDF.
New in version 1.4.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
minDocFreq = Param(parent='undefined', name='minDocFreq', doc='minimum number of documents in which a term should appear for filtering')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
IDFModel
class pyspark.ml.feature.IDFModel(java_model=None)
Model fitted by IDF.
New in version 1.4.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol() Gets the value of inputCol or its default value.
getMinDocFreq() Gets the value of minDocFreq or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setOutputCol(value) Sets the value of outputCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
docFreq Returns the document frequency.
idf Returns the IDF vector.
inputCol
minDocFreq
numDocs Returns number of documents evaluated to compute idf.
outputCol
params Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getInputCol()
Gets the value of inputCol or its default value.

getMinDocFreq()
Gets the value of minDocFreq or its default value.

New in version 1.4.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.
New in version 3.0.0.
setOutputCol(value)
Sets the value of outputCol.
New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
docFreq
Returns the document frequency.
New in version 3.0.0.
idf
Returns the IDF vector.
New in version 2.0.0.
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
minDocFreq = Param(parent='undefined', name='minDocFreq', doc='minimum number of documents in which a term should appear for filtering')
numDocs
Returns the number of documents evaluated to compute idf.
New in version 3.0.0.
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
Imputer
class pyspark.ml.feature.Imputer(*args, **kwargs)
Imputation estimator for completing missing values, using the mean, median or mode of the columns in which the missing values are located. The input columns should be of numeric type. Currently Imputer does not support categorical features and possibly creates incorrect values for a categorical feature.
Note that the mean/median/mode value is computed after filtering out missing values. All Null values in the input columns are treated as missing, and so are also imputed. For computing median, pyspark.sql.DataFrame.approxQuantile() is used with a relative error of 0.001.
New in version 2.2.0.
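A minimal sketch of median imputation (not from the original reference; it assumes an active SparkSession named spark), showing where strategy and relativeError plug in:

from pyspark.ml.feature import Imputer

# Sketch only: assumes an active SparkSession named `spark`.
df = spark.createDataFrame([(1.0,), (2.0,), (float("nan"),), (40.0,)], ["a"])

# For strategy="median", relativeError is passed through to the
# approximate quantile computation, trading accuracy for speed.
imputer = Imputer(strategy="median", inputCols=["a"], outputCols=["a_out"],
                  relativeError=0.001)
imputer.fit(df).transform(df).show()  # the NaN in "a" becomes the approximate median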
Examples
>>> df = spark.createDataFrame([(1.0, float("nan")), (2.0, float("nan")), (float("nan"), 3.0),
...                             (4.0, 4.0), (5.0, 5.0)], ["a", "b"])
>>> imputer = Imputer()
>>> imputer.setInputCols(["a", "b"])
Imputer...
>>> imputer.setOutputCols(["out_a", "out_b"])
Imputer...
>>> imputer.getRelativeError()
0.001
>>> model = imputer.fit(df)
>>> model.setInputCols(["a", "b"])
ImputerModel...
>>> model.getStrategy()
'mean'
>>> model.surrogateDF.show()
+---+---+
|  a|  b|
+---+---+
|3.0|4.0|
+---+---+
...
>>> model.transform(df).show()
+---+---+-----+-----+
|  a|  b|out_a|out_b|
+---+---+-----+-----+
|1.0|NaN|  1.0|  4.0|
|2.0|NaN|  2.0|  4.0|
|NaN|3.0|  3.0|  3.0|
...
>>> imputer.setStrategy("median").setMissingValue(1.0).fit(df).transform(df).show()
+---+---+-----+-----+
|  a|  b|out_a|out_b|
+---+---+-----+-----+
|1.0|NaN|  4.0|  NaN|
...
>>> df1 = spark.createDataFrame([(1.0,), (2.0,), (float("nan"),), (4.0,), (5.0,)], ["a"])
>>> imputer1 = Imputer(inputCol="a", outputCol="out_a")
>>> model1 = imputer1.fit(df1)
>>> model1.surrogateDF.show()
+---+
|  a|
+---+
|3.0|
+---+
...
>>> model1.transform(df1).show()
+---+-----+
|  a|out_a|
+---+-----+
|1.0|  1.0|
|2.0|  2.0|
|NaN|  3.0|
...
>>> imputer1.setStrategy("median").setMissingValue(1.0).fit(df1).transform(df1).show()
+---+-----+
|  a|out_a|
+---+-----+
|1.0|  4.0|
...
>>> df2 = spark.createDataFrame([(float("nan"),), (float("nan"),), (3.0,), (4.0,), (5.0,)],
...                             ["b"])
>>> imputer2 = Imputer(inputCol="b", outputCol="out_b")
>>> model2 = imputer2.fit(df2)
>>> model2.surrogateDF.show()
+---+
|  b|
+---+
|4.0|
+---+
...
>>> model2.transform(df2).show()
+---+-----+
|  b|out_b|
+---+-----+
|NaN|  4.0|
|NaN|  4.0|
|3.0|  3.0|
...
>>> imputer2.setStrategy("median").setMissingValue(1.0).fit(df2).transform(df2).show()
+---+-----+
|  b|out_b|
+---+-----+
|NaN|  NaN|
...
>>> imputerPath = temp_path + "/imputer"
>>> imputer.save(imputerPath)
>>> loadedImputer = Imputer.load(imputerPath)
>>> loadedImputer.getStrategy() == imputer.getStrategy()
True
>>> loadedImputer.getMissingValue()
1.0
>>> modelPath = temp_path + "/imputer-model"
>>> model.save(modelPath)
>>> loadedModel = ImputerModel.load(modelPath)
>>> loadedModel.transform(df).head().out_a == model.transform(df).head().out_a
True
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getInputCol() Gets the value of inputCol or its default value.
getInputCols() Gets the value of inputCols or its default value.
getMissingValue() Gets the value of missingValue or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getOutputCols() Gets the value of outputCols or its default value.
getParam(paramName) Gets a param by its name.
getRelativeError() Gets the value of relativeError or its default value.
getStrategy() Gets the value of strategy or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setInputCols(value) Sets the value of inputCols.
setMissingValue(value) Sets the value of missingValue.
setOutputCol(value) Sets the value of outputCol.
setOutputCols(value) Sets the value of outputCols.
setParams(self, *[, strategy, ...]) Sets params for this Imputer.
setRelativeError(value) Sets the value of relativeError.
setStrategy(value) Sets the value of strategy.
write() Returns an MLWriter instance for this ML instance.
Attributes
inputCol
inputCols
missingValue
outputCol
outputCols
params Returns all params ordered by name.
relativeError
strategy
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getInputCol()
Gets the value of inputCol or its default value.

getInputCols()
Gets the value of inputCols or its default value.

getMissingValue()
Gets the value of missingValue or its default value.
New in version 2.2.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getOutputCols()
Gets the value of outputCols or its default value.

getParam(paramName)
Gets a param by its name.

getRelativeError()
Gets the value of relativeError or its default value.

getStrategy()
Gets the value of strategy or its default value.

New in version 2.2.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.
isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.
New in version 3.0.0.
setInputCols(value)
Sets the value of inputCols.

New in version 2.2.0.

setMissingValue(value)
Sets the value of missingValue.

New in version 2.2.0.

setOutputCol(value)
Sets the value of outputCol.

New in version 3.0.0.

setOutputCols(value)
Sets the value of outputCols.

New in version 2.2.0.

setParams(self, *, strategy="mean", missingValue=float("nan"), inputCols=None, outputCols=None, inputCol=None, outputCol=None, relativeError=0.001)
Sets params for this Imputer.
New in version 2.2.0.
setRelativeError(value)
Sets the value of relativeError.

New in version 3.0.0.

setStrategy(value)
Sets the value of strategy.
New in version 2.2.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')
missingValue = Param(parent='undefined', name='missingValue', doc='The placeholder for the missing values. All occurrences of missingValue will be imputed.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
relativeError = Param(parent='undefined', name='relativeError', doc='the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]')
strategy = Param(parent='undefined', name='strategy', doc='strategy for imputation. If mean, then replace missing values using the mean value of the feature. If median, then replace missing values using the median value of the feature. If mode, then replace missing using the most frequent value of the feature.')
ImputerModel
class pyspark.ml.feature.ImputerModel(java_model=None)
Model fitted by Imputer.
New in version 2.2.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol() Gets the value of inputCol or its default value.
getInputCols() Gets the value of inputCols or its default value.
getMissingValue() Gets the value of missingValue or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getOutputCols() Gets the value of outputCols or its default value.
getParam(paramName) Gets a param by its name.
getRelativeError() Gets the value of relativeError or its default value.
getStrategy() Gets the value of strategy or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setInputCols(value) Sets the value of inputCols.
setOutputCol(value) Sets the value of outputCol.
setOutputCols(value) Sets the value of outputCols.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
inputCol
inputCols
missingValue
outputCol
outputCols
params Returns all params ordered by name.
relativeError
strategy
surrogateDF Returns a DataFrame containing inputCols and their corresponding surrogates, which are used to replace the missing values in the input DataFrame.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getInputCol()
Gets the value of inputCol or its default value.

getInputCols()
Gets the value of inputCols or its default value.

getMissingValue()
Gets the value of missingValue or its default value.

New in version 2.2.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getOutputCols()
Gets the value of outputCols or its default value.

getParam(paramName)
Gets a param by its name.

getRelativeError()
Gets the value of relativeError or its default value.

getStrategy()
Gets the value of strategy or its default value.

New in version 2.2.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.
classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.
New in version 3.0.0.
setInputCols(value)
Sets the value of inputCols.

New in version 3.0.0.

setOutputCol(value)
Sets the value of outputCol.

New in version 3.0.0.

setOutputCols(value)
Sets the value of outputCols.
New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')
missingValue = Param(parent='undefined', name='missingValue', doc='The placeholder for the missing values. All occurrences of missingValue will be imputed.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
relativeError = Param(parent='undefined', name='relativeError', doc='the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]')
strategy = Param(parent='undefined', name='strategy', doc='strategy for imputation. If mean, then replace missing values using the mean value of the feature. If median, then replace missing values using the median value of the feature. If mode, then replace missing using the most frequent value of the feature.')
surrogateDF
Returns a DataFrame containing inputCols and their corresponding surrogates, which are used to replace the missing values in the input DataFrame.
New in version 2.2.0.
IndexToString
class pyspark.ml.feature.IndexToString(*args, **kwargs)
A pyspark.ml.base.Transformer that maps a column of indices back to a new column of corresponding string values. The index-string mapping is either from the ML attributes of the input column, or from user-supplied labels (which take precedence over ML attributes).
New in version 1.6.0.
See also:
StringIndexer for converting categorical values into category indices
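A sketch of the round trip with StringIndexer (not from the original reference; it assumes an active SparkSession named spark). IndexToString reads the label order from the indexed column's ML attributes unless explicit labels are supplied:

from pyspark.ml.feature import IndexToString, StringIndexer

# Sketch only: assumes an active SparkSession named `spark`.
df = spark.createDataFrame([("a",), ("b",), ("a",)], ["category"])

# StringIndexer records the label order in the output column's metadata...
indexed = StringIndexer(inputCol="category", outputCol="idx").fit(df).transform(df)

# ...which IndexToString uses to map the indices back to strings.
restored = IndexToString(inputCol="idx", outputCol="original").transform(indexed)
restored.show()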
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol() Gets the value of inputCol or its default value.
getLabels() Gets the value of labels or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setLabels(value) Sets the value of labels.
setOutputCol(value) Sets the value of outputCol.
setParams(self, *[, inputCol, outputCol, ...]) Sets params for this IndexToString.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
inputCol
labels
outputCol
params: Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
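The merge ordering can be checked directly. A small sketch, reusing the converter from the example above; an extra value overrides the user-supplied outputCol for the same param:

>>> converter.extractParamMap({converter.outputCol: "override"})[converter.outputCol]
'override'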
getInputCol()
    Gets the value of inputCol or its default value.

getLabels()
    Gets the value of labels or its default value.

    New in version 1.6.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by the user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by the user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

setLabels(value)
    Sets the value of labels.

    New in version 1.6.0.

setOutputCol(value)
    Sets the value of outputCol.

setParams(self, \*, inputCol=None, outputCol=None, labels=None)
    Sets params for this IndexToString.

    New in version 1.6.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
labels = Param(parent='undefined', name='labels', doc='Optional array of labels specifying index-string mapping. If not provided or if empty, then metadata from inputCol is used instead.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
Interaction
class pyspark.ml.feature.Interaction(*args, **kwargs)
    Implements the feature interaction transform. This transformer takes in Double and Vector type columns and outputs a flattened vector of their feature interactions. To handle interaction, we first one-hot encode any nominal features. Then, a vector of the feature cross-products is produced.

    For example, given the input feature values Double(2) and Vector(3, 4), the output would be Vector(6, 8) if all input features were numeric. If the first feature was instead nominal with four categories, the output would then be Vector(0, 0, 0, 0, 3, 4, 0, 0); see the sketch after the example below.
New in version 3.0.0.
Examples
>>> df = spark.createDataFrame([(0.0, 1.0), (2.0, 3.0)], ["a", "b"])
>>> interaction = Interaction()
>>> interaction.setInputCols(["a", "b"])
Interaction...
>>> interaction.setOutputCol("ab")
Interaction...
>>> interaction.transform(df).show()
+---+---+-----+
|  a|  b|   ab|
+---+---+-----+
|0.0|1.0|[0.0]|
|2.0|3.0|[6.0]|
+---+---+-----+
...
>>> interactionPath = temp_path + "/interaction"
>>> interaction.save(interactionPath)
>>> loadedInteraction = Interaction.load(interactionPath)
>>> loadedInteraction.transform(df).head().ab == interaction.transform(df).head().ab
True
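The nominal case described above can be checked by hand. A plain-Python sketch of the underlying cross-product (illustrative only, not the transformer itself):

>>> one_hot = [0.0, 0.0, 1.0, 0.0]  # Double(2) one-hot encoded over four categories
>>> numeric = [3.0, 4.0]  # the entries of Vector(3, 4)
>>> [a * b for a in one_hot for b in numeric]
[0.0, 0.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]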
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCols(): Gets the value of inputCols or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by the user or has a default value.
isSet(param): Checks whether a param is explicitly set by the user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setInputCols(value): Sets the value of inputCols.
setOutputCol(value): Sets the value of outputCol.
setParams(self, \*[, inputCols, outputCol]): Sets params for this Interaction.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
inputCols
outputCol
params: Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getInputCols()
    Gets the value of inputCols or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.
hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by the user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by the user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setInputCols(value)
    Sets the value of inputCols.

    New in version 3.0.0.

setOutputCol(value)
    Sets the value of outputCol.

    New in version 3.0.0.

setParams(self, \*, inputCols=None, outputCol=None)
    Sets params for this Interaction.

    New in version 3.0.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
MaxAbsScaler
class pyspark.ml.feature.MaxAbsScaler(*args, **kwargs)
    Rescale each feature individually to range [-1, 1] by dividing through the largest maximum absolute value in each feature. It does not shift/center the data, and thus does not destroy any sparsity.
New in version 2.0.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([1.0]),), (Vectors.dense([2.0]),)], ["a"])
>>> maScaler = MaxAbsScaler(outputCol="scaled")
>>> maScaler.setInputCol("a")
MaxAbsScaler...
>>> model = maScaler.fit(df)
>>> model.setOutputCol("scaledOutput")
MaxAbsScalerModel...
>>> model.transform(df).show()
+-----+------------+
|    a|scaledOutput|
+-----+------------+
|[1.0]|       [0.5]|
|[2.0]|       [1.0]|
+-----+------------+
...
>>> scalerPath = temp_path + "/max-abs-scaler"
>>> maScaler.save(scalerPath)
>>> loadedMAScaler = MaxAbsScaler.load(scalerPath)
>>> loadedMAScaler.getInputCol() == maScaler.getInputCol()
True
>>> loadedMAScaler.getOutputCol() == maScaler.getOutputCol()
True
>>> modelPath = temp_path + "/max-abs-scaler-model"
>>> model.save(modelPath)
>>> loadedModel = MaxAbsScalerModel.load(modelPath)
>>> loadedModel.maxAbs == model.maxAbs
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getInputCol(): Gets the value of inputCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by the user or has a default value.
isSet(param): Checks whether a param is explicitly set by the user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
setParams(self, \*[, inputCol, outputCol]): Sets params for this MaxAbsScaler.
write(): Returns an MLWriter instance for this ML instance.
Attributes
inputCol
outputCol
params: Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
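A hedged sketch of this contract, reusing maScaler and df from the example above (the param maps here are illustrative); each yielded pair carries the index of the param map the model was fit with, possibly out of order:

>>> maps = [{maScaler.outputCol: "s1"}, {maScaler.outputCol: "s2"}]
>>> sorted(index for index, model in maScaler.fitMultiple(df, maps))
[0, 1]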
getInputCol()
    Gets the value of inputCol or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by the user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by the user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

setOutputCol(value)
    Sets the value of outputCol.

setParams(self, \*, inputCol=None, outputCol=None)
    Sets params for this MaxAbsScaler.
New in version 2.0.0.
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
MaxAbsScalerModel
class pyspark.ml.feature.MaxAbsScalerModel(java_model=None)
    Model fitted by MaxAbsScaler.
New in version 2.0.0.
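Continuing the MaxAbsScaler example above, where the single feature's largest absolute value is 2.0, the fitted statistic can be read off the model. A brief sketch, assuming model is the MaxAbsScalerModel from that example:

>>> model.maxAbs
DenseVector([2.0])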
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by the user or has a default value.
isSet(param): Checks whether a param is explicitly set by the user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
inputCol
maxAbs: Max Abs vector.
outputCol
params: Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getInputCol()
    Gets the value of inputCol or its default value.
getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by the user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by the user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

    New in version 3.0.0.

setOutputCol(value)
    Sets the value of outputCol.

    New in version 3.0.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
maxAbs
    Max Abs vector.
New in version 2.0.0.
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
MinHashLSH
class pyspark.ml.feature.MinHashLSH(*args, **kwargs)
    LSH class for Jaccard distance. The input can be dense or sparse vectors, but it is more efficient if it is sparse. For example, Vectors.sparse(10, [(2, 1.0), (3, 1.0), (5, 1.0)]) means there are 10 elements in the space. This set contains elements 2, 3, and 5. Also, any input vector must have at least 1 non-zero index, and all non-zero values are treated as binary "1" values.
New in version 2.2.0.
Notes
See Wikipedia on MinHash
Examples
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.sql.functions import col
>>> data = [(0, Vectors.sparse(6, [0, 1, 2], [1.0, 1.0, 1.0]),),
...         (1, Vectors.sparse(6, [2, 3, 4], [1.0, 1.0, 1.0]),),
...         (2, Vectors.sparse(6, [0, 2, 4], [1.0, 1.0, 1.0]),)]
>>> df = spark.createDataFrame(data, ["id", "features"])
>>> mh = MinHashLSH()
>>> mh.setInputCol("features")
MinHashLSH...
>>> mh.setOutputCol("hashes")
MinHashLSH...
>>> mh.setSeed(12345)
MinHashLSH...
>>> model = mh.fit(df)
>>> model.setInputCol("features")
MinHashLSHModel...
>>> model.transform(df).head()
Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668...
>>> data2 = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),),
...          (4, Vectors.sparse(6, [2, 3, 5], [1.0, 1.0, 1.0]),),
...          (5, Vectors.sparse(6, [1, 2, 4], [1.0, 1.0, 1.0]),)]
>>> df2 = spark.createDataFrame(data2, ["id", "features"])
>>> key = Vectors.sparse(6, [1, 2], [1.0, 1.0])
>>> model.approxNearestNeighbors(df2, key, 1).collect()
[Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([6179668...
>>> model.approxSimilarityJoin(df, df2, 0.6, distCol="JaccardDistance").select(
...     col("datasetA.id").alias("idA"),
...     col("datasetB.id").alias("idB"),
...     col("JaccardDistance")).show()
+---+---+---------------+
|idA|idB|JaccardDistance|
+---+---+---------------+
|  0|  5|            0.5|
|  1|  4|            0.5|
+---+---+---------------+
...
>>> mhPath = temp_path + "/mh"
>>> mh.save(mhPath)
>>> mh2 = MinHashLSH.load(mhPath)
>>> mh2.getOutputCol() == mh.getOutputCol()
True
>>> modelPath = temp_path + "/mh-model"
>>> model.save(modelPath)
>>> model2 = MinHashLSHModel.load(modelPath)
>>> model.transform(df).head().hashes == model2.transform(df).head().hashes
True
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getInputCol(): Gets the value of inputCol or its default value.
getNumHashTables(): Gets the value of numHashTables or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
getSeed(): Gets the value of seed or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by the user or has a default value.
isSet(param): Checks whether a param is explicitly set by the user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setNumHashTables(value): Sets the value of numHashTables.
setOutputCol(value): Sets the value of outputCol.
setParams(self, \*[, inputCol, outputCol, seed, numHashTables]): Sets params for this MinHashLSH.
setSeed(value): Sets the value of seed.
write(): Returns an MLWriter instance for this ML instance.
Attributes
inputCol
numHashTables
outputCol
params: Returns all params ordered by name.
seed
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getInputCol()
    Gets the value of inputCol or its default value.

getNumHashTables()
    Gets the value of numHashTables or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

getSeed()
    Gets the value of seed or its default value.
hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by the user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by the user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

setNumHashTables(value)
    Sets the value of numHashTables.

setOutputCol(value)
    Sets the value of outputCol.

setParams(self, \*, inputCol=None, outputCol=None, seed=None, numHashTables=1)
    Sets params for this MinHashLSH.

    New in version 2.2.0.

setSeed(value)
    Sets the value of seed.

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
numHashTables = Param(parent='undefined', name='numHashTables', doc='number of hash tables, where increasing number of hash tables lowers the false negative rate, and decreasing it improves the running performance.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
seed = Param(parent='undefined', name='seed', doc='random seed.')
MinHashLSHModel
class pyspark.ml.feature.MinHashLSHModel(java_model=None)
    Model produced by MinHashLSH, where multiple hash functions are stored. Each hash function is picked from the following family of hash functions, where a_i and b_i are randomly chosen integers less than prime: h_i(x) = ((x * a_i + b_i) mod prime). This hash family is approximately min-wise independent according to the reference.
New in version 2.2.0.
Notes
See Tom Bohman, Colin Cooper, and Alan Frieze. "Min-wise independent linear permutations." Electronic Journal of Combinatorics 7 (2000): R26.
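A plain-Python sketch of one member of this hash family, applied to the non-zero indices of an input vector (the constants a, b, and prime below are illustrative; the implementation draws a_i and b_i randomly):

>>> a, b, prime = 3, 7, 2038074743
>>> def h(x):
...     return (x * a + b) % prime
>>> # The MinHash value of a set is the minimum of h over its elements.
>>> min(h(x) for x in [0, 2, 4])
7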
Methods
approxNearestNeighbors(dataset, key, numNearestNeighbors[, distCol]): Given a large dataset and an item, approximately find at most k items which have the closest distance to the item.
approxSimilarityJoin(datasetA, datasetB, threshold[, distCol]): Join two datasets to approximately find all pairs of rows whose distances are smaller than the threshold.
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getNumHashTables(): Gets the value of numHashTables or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by the user or has a default value.
isSet(param): Checks whether a param is explicitly set by the user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
inputCol
numHashTables
outputCol
params: Returns all params ordered by name.
Methods Documentation
approxNearestNeighbors(dataset, key, numNearestNeighbors, distCol='distCol')
    Given a large dataset and an item, approximately find at most k items which have the closest distance to the item. If the outputCol is missing, the method will transform the data; if the outputCol exists, it will use that. This allows caching of the transformed data when necessary.
Parameters
dataset [pyspark.sql.DataFrame] The dataset to search for nearest neighbors of thekey.
key [pyspark.ml.linalg.Vector] Feature vector representing the item to search for.
numNearestNeighbors [int] The maximum number of nearest neighbors.
distCol [str] Output column for storing the distance between each result row and the key.Use “distCol” as default value if it’s not specified.
Returns
pyspark.sql.DataFrame A dataset containing at most k items closest to the key. Acolumn “distCol” is added to show the distance between each row and the key.
Notes
This method is experimental and will likely change behavior in the next release.
approxSimilarityJoin(datasetA, datasetB, threshold, distCol='distCol')
    Join two datasets to approximately find all pairs of rows whose distances are smaller than the threshold. If the outputCol is missing, the method will transform the data; if the outputCol exists, it will use that. This allows caching of the transformed data when necessary.
Parameters
datasetA [pyspark.sql.DataFrame] One of the datasets to join.
datasetB [pyspark.sql.DataFrame] Another dataset to join.
threshold [float] The threshold for the distance of row pairs.
distCol [str, optional] Output column for storing the distance between each pair of rows.Use “distCol” as default value if it’s not specified.
Returns
pyspark.sql.DataFrame A joined dataset containing pairs of rows. The original rowsare in columns “datasetA” and “datasetB”, and a column “distCol” is added to show thedistance between each pair.
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getInputCol()
    Gets the value of inputCol or its default value.

getNumHashTables()
    Gets the value of numHashTables or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.
hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by the user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by the user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

setOutputCol(value)
    Sets the value of outputCol.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
numHashTables = Param(parent='undefined', name='numHashTables', doc='number of hash tables, where increasing number of hash tables lowers the false negative rate, and decreasing it improves the running performance.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
MinMaxScaler
class pyspark.ml.feature.MinMaxScaler(*args, **kwargs)
    Rescale each feature individually to a common range [min, max] linearly using column summary statistics, which is also known as min-max normalization or Rescaling. The rescaled value for feature E is calculated as,
Rescaled(e_i) = (e_i - E_min) / (E_max - E_min) * (max - min) + min
For the case E_max == E_min, Rescaled(e_i) = 0.5 * (max + min)
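A worked check of the formula with the defaults min=0.0 and max=1.0 and an observed feature range [0.0, 2.0] (matching the example below), for e_i = 2.0:

>>> (2.0 - 0.0) / (2.0 - 0.0) * (1.0 - 0.0) + 0.0
1.0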
New in version 1.6.0.
Notes
Since zero values will probably be transformed to non-zero values, output of the transformer will be DenseVector even for sparse input.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"])
>>> mmScaler = MinMaxScaler(outputCol="scaled")
>>> mmScaler.setInputCol("a")
MinMaxScaler...
>>> model = mmScaler.fit(df)
>>> model.setOutputCol("scaledOutput")
MinMaxScalerModel...
>>> model.originalMin
DenseVector([0.0])
>>> model.originalMax
DenseVector([2.0])
>>> model.transform(df).show()
+-----+------------+
|    a|scaledOutput|
+-----+------------+
|[0.0]|       [0.0]|
|[2.0]|       [1.0]|
+-----+------------+
...
>>> minMaxScalerPath = temp_path + "/min-max-scaler"
>>> mmScaler.save(minMaxScalerPath)
>>> loadedMMScaler = MinMaxScaler.load(minMaxScalerPath)
>>> loadedMMScaler.getMin() == mmScaler.getMin()
True
>>> loadedMMScaler.getMax() == mmScaler.getMax()
True
>>> modelPath = temp_path + "/min-max-scaler-model"
>>> model.save(modelPath)
>>> loadedModel = MinMaxScalerModel.load(modelPath)
>>> loadedModel.originalMin == model.originalMin
True
>>> loadedModel.originalMax == model.originalMax
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getInputCol(): Gets the value of inputCol or its default value.
getMax(): Gets the value of max or its default value.
getMin(): Gets the value of min or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by the user or has a default value.
isSet(param): Checks whether a param is explicitly set by the user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setMax(value): Sets the value of max.
setMin(value): Sets the value of min.
setOutputCol(value): Sets the value of outputCol.
setParams(self, \*[, min, max, inputCol, outputCol]): Sets params for this MinMaxScaler.
write(): Returns an MLWriter instance for this ML instance.
Attributes
inputCol
max
min
outputCol
params: Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getInputCol()
    Gets the value of inputCol or its default value.

getMax()
    Gets the value of max or its default value.

    New in version 1.6.0.

getMin()
    Gets the value of min or its default value.

    New in version 1.6.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by the user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by the user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.
setInputCol(value)
    Sets the value of inputCol.

setMax(value)
    Sets the value of max.

    New in version 1.6.0.

setMin(value)
    Sets the value of min.

    New in version 1.6.0.

setOutputCol(value)
    Sets the value of outputCol.

setParams(self, \*, min=0.0, max=1.0, inputCol=None, outputCol=None)
    Sets params for this MinMaxScaler.

    New in version 1.6.0.

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
max = Param(parent='undefined', name='max', doc='Upper bound of the output feature range')
min = Param(parent='undefined', name='min', doc='Lower bound of the output feature range')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
MinMaxScalerModel
class pyspark.ml.feature.MinMaxScalerModel(java_model=None)
    Model fitted by MinMaxScaler.
New in version 1.6.0.
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getMax(): Gets the value of max or its default value.
getMin(): Gets the value of min or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by the user or has a default value.
isSet(param): Checks whether a param is explicitly set by the user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setMax(value): Sets the value of max.
setMin(value): Sets the value of min.
setOutputCol(value): Sets the value of outputCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
inputCol
max
min
originalMax: Max value for each original column during fitting.
originalMin: Min value for each original column during fitting.
outputCol
params: Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getInputCol()
    Gets the value of inputCol or its default value.

getMax()
    Gets the value of max or its default value.

    New in version 1.6.0.

getMin()
    Gets the value of min or its default value.

    New in version 1.6.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.
hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by the user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by the user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

    New in version 3.0.0.

setMax(value)
    Sets the value of max.

    New in version 3.0.0.

setMin(value)
    Sets the value of min.

    New in version 3.0.0.

setOutputCol(value)
    Sets the value of outputCol.

    New in version 3.0.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
max = Param(parent='undefined', name='max', doc='Upper bound of the output feature range')
min = Param(parent='undefined', name='min', doc='Lower bound of the output feature range')
originalMax
    Max value for each original column during fitting.
New in version 2.0.0.
originalMin
    Min value for each original column during fitting.
New in version 2.0.0.
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
NGram
class pyspark.ml.feature.NGram(*args, **kwargs)
    A feature transformer that converts the input array of strings into an array of n-grams. Null values in the input array are ignored. It returns an array of n-grams where each n-gram is represented by a space-separated string of words. When the input is empty, an empty array is returned. When the input array length is less than n (number of elements per n-gram), no n-grams are returned.
New in version 1.5.0.
Examples
>>> df = spark.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])])
>>> ngram = NGram(n=2)
>>> ngram.setInputCol("inputTokens")
NGram...
>>> ngram.setOutputCol("nGrams")
NGram...
>>> ngram.transform(df).head()
Row(inputTokens=['a', 'b', 'c', 'd', 'e'], nGrams=['a b', 'b c', 'c d', 'd e'])
>>> # Change n-gram length
>>> ngram.setParams(n=4).transform(df).head()
Row(inputTokens=['a', 'b', 'c', 'd', 'e'], nGrams=['a b c d', 'b c d e'])
>>> # Temporarily modify output column.
>>> ngram.transform(df, {ngram.outputCol: "output"}).head()
Row(inputTokens=['a', 'b', 'c', 'd', 'e'], output=['a b c d', 'b c d e'])
>>> ngram.transform(df).head()
Row(inputTokens=['a', 'b', 'c', 'd', 'e'], nGrams=['a b c d', 'b c d e'])
>>> # Must use keyword arguments to specify params.
>>> ngram.setParams("text")
Traceback (most recent call last):
    ...
TypeError: Method setParams forces keyword arguments.
>>> ngramPath = temp_path + "/ngram"
>>> ngram.save(ngramPath)
>>> loadedNGram = NGram.load(ngramPath)
>>> loadedNGram.getN() == ngram.getN()
True
>>> loadedNGram.transform(df).take(1) == ngram.transform(df).take(1)
True
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getN(): Gets the value of n or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by the user or has a default value.
isSet(param): Checks whether a param is explicitly set by the user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setN(value): Sets the value of n.
setOutputCol(value): Sets the value of outputCol.
setParams(self, \*[, n, inputCol, outputCol]): Sets params for this NGram.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
inputCol
n
outputCol
params: Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map

getInputCol()
    Gets the value of inputCol or its default value.

getN()
    Gets the value of n or its default value.

    New in version 1.5.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.
getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

setN(value)
    Sets the value of n.

    New in version 1.5.0.

setOutputCol(value)
    Sets the value of outputCol.

setParams(self, *, n=2, inputCol=None, outputCol=None)
    Sets params for this NGram.

    New in version 1.5.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset
        params [dict, optional] an optional param map that overrides embedded params.

    Returns
        pyspark.sql.DataFrame transformed dataset

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
n = Param(parent='undefined', name='n', doc='number of elements per n-gram (>=1)')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
Normalizer
class pyspark.ml.feature.Normalizer(*args, **kwargs)
    Normalize a vector to have unit norm using the given p-norm.
New in version 1.4.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> svec = Vectors.sparse(4, {1: 4.0, 3: 3.0})
>>> df = spark.createDataFrame([(Vectors.dense([3.0, -4.0]), svec)], ["dense", "sparse"])
>>> normalizer = Normalizer(p=2.0)
>>> normalizer.setInputCol("dense")
Normalizer...
>>> normalizer.setOutputCol("features")
Normalizer...
>>> normalizer.transform(df).head().features
DenseVector([0.6, -0.8])
>>> normalizer.setParams(inputCol="sparse", outputCol="freqs").transform(df).head().freqs
SparseVector(4, {1: 0.8, 3: 0.6})
>>> params = {normalizer.p: 1.0, normalizer.inputCol: "dense", normalizer.outputCol: "vector"}
>>> normalizer.transform(df, params).head().vector
DenseVector([0.4286, -0.5714])
>>> normalizerPath = temp_path + "/normalizer"
>>> normalizer.save(normalizerPath)
>>> loadedNormalizer = Normalizer.load(normalizerPath)
>>> loadedNormalizer.getP() == normalizer.getP()
True
>>> loadedNormalizer.transform(df).take(1) == normalizer.transform(df).take(1)
True
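As a quick numeric check of the first result above, computed with the plain p-norm definition rather than Spark (illustrative only):

# With p=2, the vector [3.0, -4.0] has norm (3**2 + 4**2) ** 0.5 == 5.0,
# so the unit-norm result is [0.6, -0.8], matching the doctest output.
v = [3.0, -4.0]
p = 2.0
norm = sum(abs(x) ** p for x in v) ** (1.0 / p)
print([x / norm for x in v])  # [0.6, -0.8]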
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getP(): Gets the value of p or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
setP(value): Sets the value of p.
setParams(self, *[, p, inputCol, outputCol]): Sets params for this Normalizer.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
inputCol
outputCol
p
params: Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map

getInputCol()
    Gets the value of inputCol or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getP()
    Gets the value of p or its default value.

    New in version 1.4.0.
getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

setOutputCol(value)
    Sets the value of outputCol.

setP(value)
    Sets the value of p.

    New in version 1.4.0.

setParams(self, *, p=2.0, inputCol=None, outputCol=None)
    Sets params for this Normalizer.

    New in version 1.4.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset
        params [dict, optional] an optional param map that overrides embedded params.

    Returns
        pyspark.sql.DataFrame transformed dataset

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
p = Param(parent='undefined', name='p', doc='the p norm value.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
OneHotEncoder
class pyspark.ml.feature.OneHotEncoder(*args, **kwargs)
    A one-hot encoder that maps a column of category indices to a column of binary vectors, with at most a single one-value per row that indicates the input category index. For example with 5 categories, an input value of 2.0 would map to an output vector of [0.0, 0.0, 1.0, 0.0]. The last category is not included by default (configurable via dropLast), because it makes the vector entries sum up to one, and hence linearly dependent. So an input value of 4.0 maps to [0.0, 0.0, 0.0, 0.0].

    When handleInvalid is configured to 'keep', an extra "category" indicating invalid values is added as the last category. So when dropLast is true, invalid values are encoded as the all-zeros vector.
New in version 2.3.0.
See also:
StringIndexer for converting categorical values into category indices
Notes
This is different from scikit-learn’s OneHotEncoder, which keeps all categories. The output vectors are sparse.
When encoding multiple columns via the inputCols and outputCols params, input/output cols come in pairs, specified by the order in the arrays, and each pair is treated independently, as the sketch below illustrates.
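A minimal, hedged sketch of that pairing (the column names and data are illustrative; assumes a running SparkSession named spark):

from pyspark.ml.feature import OneHotEncoder

# "in1" encodes to "out1" and "in2" to "out2"; each pair is handled independently.
df = spark.createDataFrame([(0.0, 1.0), (1.0, 0.0), (2.0, 1.0)], ["in1", "in2"])
ohe = OneHotEncoder(inputCols=["in1", "in2"], outputCols=["out1", "out2"])
ohe.fit(df).transform(df).show()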
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"])
>>> ohe = OneHotEncoder()
>>> ohe.setInputCols(["input"])
OneHotEncoder...
>>> ohe.setOutputCols(["output"])
OneHotEncoder...
>>> model = ohe.fit(df)
>>> model.setOutputCols(["output"])
OneHotEncoderModel...
>>> model.getHandleInvalid()
'error'
>>> model.transform(df).head().output
SparseVector(2, {0: 1.0})
>>> single_col_ohe = OneHotEncoder(inputCol="input", outputCol="output")
>>> single_col_model = single_col_ohe.fit(df)
>>> single_col_model.transform(df).head().output
SparseVector(2, {0: 1.0})
>>> ohePath = temp_path + "/ohe"
>>> ohe.save(ohePath)
>>> loadedOHE = OneHotEncoder.load(ohePath)
>>> loadedOHE.getInputCols() == ohe.getInputCols()
True
>>> modelPath = temp_path + "/ohe-model"
>>> model.save(modelPath)
>>> loadedModel = OneHotEncoderModel.load(modelPath)
>>> loadedModel.categorySizes == model.categorySizes
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
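To make the dropLast semantics from the class description concrete, here is an illustrative pure-Python rendering of the index-to-vector mapping (the one_hot helper is hypothetical, not part of the API):

def one_hot(index, num_categories, drop_last=True):
    # With drop_last, the vector has num_categories - 1 slots, so the last
    # category (index == num_categories - 1) becomes the all-zeros vector.
    size = num_categories - 1 if drop_last else num_categories
    vec = [0.0] * size
    if index < size:
        vec[index] = 1.0
    return vec

print(one_hot(2, 5))  # [0.0, 0.0, 1.0, 0.0]
print(one_hot(4, 5))  # [0.0, 0.0, 0.0, 0.0]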
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getDropLast(): Gets the value of dropLast or its default value.
getHandleInvalid(): Gets the value of handleInvalid or its default value.
getInputCol(): Gets the value of inputCol or its default value.
getInputCols(): Gets the value of inputCols or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getOutputCols(): Gets the value of outputCols or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setDropLast(value): Sets the value of dropLast.
setHandleInvalid(value): Sets the value of handleInvalid.
setInputCol(value): Sets the value of inputCol.
setInputCols(value): Sets the value of inputCols.
setOutputCol(value): Sets the value of outputCol.
setOutputCols(value): Sets the value of outputCols.
setParams(self, *[, inputCols, outputCols, ...]): Sets params for this OneHotEncoder.
write(): Returns an MLWriter instance for this ML instance.
Attributes
dropLast
handleInvalid
inputCol
inputCols
outputCol
outputCols
params: Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

    Returns
        Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.

    New in version 2.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        paramMaps [collections.abc.Sequence] A Sequence of param maps.

    Returns
        _FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
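A short, hedged usage sketch of fitMultiple (the DataFrame and param values are illustrative; assumes a running SparkSession named spark):

from pyspark.ml.feature import OneHotEncoder

df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"])
ohe = OneHotEncoder(inputCols=["input"], outputCols=["output"])
param_maps = [{ohe.dropLast: True}, {ohe.dropLast: False}]
models = [None] * len(param_maps)
for index, model in ohe.fitMultiple(df, param_maps):
    models[index] = model  # models may arrive out of order; index restores it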
getDropLast()
    Gets the value of dropLast or its default value.

    New in version 2.3.0.

getHandleInvalid()
    Gets the value of handleInvalid or its default value.

getInputCol()
    Gets the value of inputCol or its default value.

getInputCols()
    Gets the value of inputCols or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.
getOutputCols()
    Gets the value of outputCols or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setDropLast(value)
    Sets the value of dropLast.

    New in version 2.3.0.

setHandleInvalid(value)
    Sets the value of handleInvalid.

    New in version 3.0.0.

setInputCol(value)
    Sets the value of inputCol.

    New in version 3.0.0.

setInputCols(value)
    Sets the value of inputCols.

    New in version 3.0.0.

setOutputCol(value)
    Sets the value of outputCol.

    New in version 3.0.0.

setOutputCols(value)
    Sets the value of outputCols.

    New in version 3.0.0.

setParams(self, *, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, inputCol=None, outputCol=None)
    Sets params for this OneHotEncoder.

    New in version 2.3.0.
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
dropLast = Param(parent='undefined', name='dropLast', doc='whether to drop the last category')
handleInvalid = Param(parent='undefined', name='handleInvalid', doc="How to handle invalid data during transform(). Options are 'keep' (invalid data presented as an extra categorical feature) or error (throw an error). Note that this Param is only used during transform; during fitting, invalid data will result in an error.")
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
OneHotEncoderModel
class pyspark.ml.feature.OneHotEncoderModel(java_model=None)
    Model fitted by OneHotEncoder.
New in version 2.3.0.
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getDropLast(): Gets the value of dropLast or its default value.
getHandleInvalid(): Gets the value of handleInvalid or its default value.
getInputCol(): Gets the value of inputCol or its default value.
getInputCols(): Gets the value of inputCols or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getOutputCols(): Gets the value of outputCols or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setDropLast(value): Sets the value of dropLast.
setHandleInvalid(value): Sets the value of handleInvalid.
setInputCol(value): Sets the value of inputCol.
setInputCols(value): Sets the value of inputCols.
setOutputCol(value): Sets the value of outputCol.
setOutputCols(value): Sets the value of outputCols.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
categorySizes: Original number of categories for each feature being encoded.
dropLast
handleInvalid
inputCol
inputCols
outputCol
outputCols
params: Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map

getDropLast()
    Gets the value of dropLast or its default value.

    New in version 2.3.0.

getHandleInvalid()
    Gets the value of handleInvalid or its default value.

getInputCol()
    Gets the value of inputCol or its default value.

getInputCols()
    Gets the value of inputCols or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getOutputCols()
    Gets the value of outputCols or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.
save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setDropLast(value)
    Sets the value of dropLast.

    New in version 3.0.0.

setHandleInvalid(value)
    Sets the value of handleInvalid.

    New in version 3.0.0.

setInputCol(value)
    Sets the value of inputCol.

    New in version 3.0.0.

setInputCols(value)
    Sets the value of inputCols.

    New in version 3.0.0.

setOutputCol(value)
    Sets the value of outputCol.

    New in version 3.0.0.

setOutputCols(value)
    Sets the value of outputCols.

    New in version 3.0.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset
        params [dict, optional] an optional param map that overrides embedded params.

    Returns
        pyspark.sql.DataFrame transformed dataset

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
categorySizes
    Original number of categories for each feature being encoded. The array contains one value for each input column, in order.

    New in version 2.3.0.
dropLast = Param(parent='undefined', name='dropLast', doc='whether to drop the last category')
handleInvalid = Param(parent='undefined', name='handleInvalid', doc="How to handle invalid data during transform(). Options are 'keep' (invalid data presented as an extra categorical feature) or error (throw an error). Note that this Param is only used during transform; during fitting, invalid data will result in an error.")
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
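As an illustrative check (assuming model is the OneHotEncoderModel fitted in the example above, whose single input column contained the three distinct values 0.0, 1.0 and 2.0):

print(model.categorySizes)  # [3]; with dropLast=True the output vectors have size 2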
PCA
class pyspark.ml.feature.PCA(*args, **kwargs)
    PCA trains a model to project vectors to a lower dimensional space of the top k principal components.
New in version 1.5.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),),
...         (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),),
...         (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)]
>>> df = spark.createDataFrame(data, ["features"])
>>> pca = PCA(k=2, inputCol="features")
>>> pca.setOutputCol("pca_features")
PCA...
>>> model = pca.fit(df)
>>> model.getK()
2
>>> model.setOutputCol("output")
PCAModel...
>>> model.transform(df).collect()[0].output
DenseVector([1.648..., -4.013...])
>>> model.explainedVariance
DenseVector([0.794..., 0.205...])
>>> pcaPath = temp_path + "/pca"
>>> pca.save(pcaPath)
>>> loadedPca = PCA.load(pcaPath)
>>> loadedPca.getK() == pca.getK()
True
>>> modelPath = temp_path + "/pca-model"
>>> model.save(modelPath)
>>> loadedModel = PCAModel.load(modelPath)
>>> loadedModel.pc == model.pc
True
>>> loadedModel.explainedVariance == model.explainedVariance
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getInputCol(): Gets the value of inputCol or its default value.
getK(): Gets the value of k or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setK(value): Sets the value of k.
setOutputCol(value): Sets the value of outputCol.
setParams(self, *[, k, inputCol, outputCol]): Set params for this PCA.
write(): Returns an MLWriter instance for this ML instance.
Attributes
inputCol
k
outputCol
params: Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map

fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

    Returns
        Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.

    New in version 2.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        paramMaps [collections.abc.Sequence] A Sequence of param maps.

    Returns
        _FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getInputCol()
    Gets the value of inputCol or its default value.

getK()
    Gets the value of k or its default value.

    New in version 1.5.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

setK(value)
    Sets the value of k.

    New in version 1.5.0.

setOutputCol(value)
    Sets the value of outputCol.

setParams(self, *, k=None, inputCol=None, outputCol=None)
    Set params for this PCA.

    New in version 1.5.0.

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
k = Param(parent='undefined', name='k', doc='the number of principal components')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
PCAModel
class pyspark.ml.feature.PCAModel(java_model=None)
    Model fitted by PCA. Transforms vectors to a lower dimensional space.
New in version 1.5.0.
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getK(): Gets the value of k or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
explainedVariance: Returns a vector of proportions of variance explained by each principal component.
inputCol
k
outputCol
params: Returns all params ordered by name.
pc: Returns a principal components Matrix.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map

getInputCol()
    Gets the value of inputCol or its default value.

getK()
    Gets the value of k or its default value.

    New in version 1.5.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.
isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

    New in version 3.0.0.

setOutputCol(value)
    Sets the value of outputCol.

    New in version 3.0.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset
        params [dict, optional] an optional param map that overrides embedded params.

    Returns
        pyspark.sql.DataFrame transformed dataset

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
explainedVariance
    Returns a vector of proportions of variance explained by each principal component.

    New in version 2.0.0.
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
k = Param(parent='undefined', name='k', doc='the number of principal components')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

pc
    Returns a principal components Matrix. Each column is one principal component.

    New in version 2.0.0.
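A hedged sketch relating pc to transform(): projecting a feature vector onto the principal components by hand should reproduce what the model computes for that row (assumes the model and df from the PCA example above, plus NumPy):

import numpy as np

components = np.array(model.pc.toArray())   # shape: (n_features, k)
v = np.array(df.head().features.toArray())  # one original feature vector
print(v @ components)                       # matches model.transform(df) for this row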
PolynomialExpansion
class pyspark.ml.feature.PolynomialExpansion(*args, **kwargs)
    Perform feature expansion in a polynomial space. As the Wikipedia article on Polynomial Expansion puts it, "In mathematics, an expansion of a product of sums expresses it as a sum of products by using the fact that multiplication distributes over addition". Take a 2-variable feature vector as an example: (x, y). If we want to expand it with degree 2, then we get (x, x * x, y, x * y, y * y).
New in version 1.4.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([0.5, 2.0]),)], ["dense"])
>>> px = PolynomialExpansion(degree=2)
>>> px.setInputCol("dense")
PolynomialExpansion...
>>> px.setOutputCol("expanded")
PolynomialExpansion...
>>> px.transform(df).head().expanded
DenseVector([0.5, 0.25, 2.0, 1.0, 4.0])
>>> px.setParams(outputCol="test").transform(df).head().test
DenseVector([0.5, 0.25, 2.0, 1.0, 4.0])
>>> polyExpansionPath = temp_path + "/poly-expansion"
>>> px.save(polyExpansionPath)
>>> loadedPx = PolynomialExpansion.load(polyExpansionPath)
>>> loadedPx.getDegree() == px.getDegree()
True
>>> loadedPx.transform(df).take(1) == px.transform(df).take(1)
True
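An illustrative pure-Python check of the degree-2 expansion order described above, for the vector (x, y) = (0.5, 2.0) used in the doctest:

x, y = 0.5, 2.0
expanded = [x, x * x, y, x * y, y * y]
print(expanded)  # [0.5, 0.25, 2.0, 1.0, 4.0], the DenseVector contents above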
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getDegree(): Gets the value of degree or its default value.
getInputCol(): Gets the value of inputCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setDegree(value): Sets the value of degree.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
setParams(self, *[, degree, inputCol, ...]): Sets params for this PolynomialExpansion.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
degree
inputCol
outputCol
params: Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map

getDegree()
    Gets the value of degree or its default value.

    New in version 1.4.0.

getInputCol()
    Gets the value of inputCol or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setDegree(value)
    Sets the value of degree.

    New in version 1.4.0.

setInputCol(value)
    Sets the value of inputCol.

setOutputCol(value)
    Sets the value of outputCol.
setParams(self, *, degree=2, inputCol=None, outputCol=None)
    Sets params for this PolynomialExpansion.

    New in version 1.4.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset
        params [dict, optional] an optional param map that overrides embedded params.

    Returns
        pyspark.sql.DataFrame transformed dataset

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
degree = Param(parent='undefined', name='degree', doc='the polynomial degree to expand (>= 1)')
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
QuantileDiscretizer
class pyspark.ml.feature.QuantileDiscretizer(*args, **kwargs)
    QuantileDiscretizer takes a column with continuous features and outputs a column with binned categorical features. The number of bins can be set using the numBuckets parameter. It is possible that the number of buckets used will be less than this value, for example, if there are too few distinct values of the input to create enough distinct quantiles. Since 3.0.0, QuantileDiscretizer can map multiple columns at once by setting the inputCols parameter. If both of the inputCol and inputCols parameters are set, an Exception will be thrown. To specify the number of buckets for each column, the numBucketsArray parameter can be set, or if the number of buckets should be the same across columns, numBuckets can be set as a convenience.
New in version 2.0.0.
Notes
NaN handling: Note also that QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user can also choose to either keep or remove NaN values within the dataset by setting the handleInvalid parameter. If the user chooses to keep NaN values, they will be handled specially and placed into their own bucket; for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4].
Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for approxQuantile() for a detailed description). The precision of the approximation can be controlled with the relativeError parameter. The lower and upper bin bounds will be -Infinity and +Infinity, covering all real values.
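A minimal, hedged sketch of the underlying quantile call (the column name and data are illustrative; assumes a running SparkSession named spark):

# QuantileDiscretizer's bin edges come from the same approximate-quantile
# machinery exposed publicly as DataFrame.approxQuantile.
df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
edges = df.approxQuantile("values", [0.0, 0.5, 1.0], 0.01)  # last argument is the relative error
print(edges)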
Examples
>>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
>>> df1 = spark.createDataFrame(values, ["values"])
>>> qds1 = QuantileDiscretizer(inputCol="values", outputCol="buckets")
>>> qds1.setNumBuckets(2)
QuantileDiscretizer...
>>> qds1.setRelativeError(0.01)
QuantileDiscretizer...
>>> qds1.setHandleInvalid("error")
QuantileDiscretizer...
>>> qds1.getRelativeError()
0.01
>>> bucketizer = qds1.fit(df1)
>>> qds1.setHandleInvalid("keep").fit(df1).transform(df1).count()
6
>>> qds1.setHandleInvalid("skip").fit(df1).transform(df1).count()
4
>>> splits = bucketizer.getSplits()
>>> splits[0]
-inf
>>> print("%2.1f" % round(splits[1], 1))
0.4
>>> bucketed = bucketizer.transform(df1).head()
>>> bucketed.buckets
0.0
>>> quantileDiscretizerPath = temp_path + "/quantile-discretizer"
>>> qds1.save(quantileDiscretizerPath)
>>> loadedQds = QuantileDiscretizer.load(quantileDiscretizerPath)
>>> loadedQds.getNumBuckets() == qds1.getNumBuckets()
True
>>> inputs = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, 1.5),
...           (float("nan"), float("nan")), (float("nan"), float("nan"))]
>>> df2 = spark.createDataFrame(inputs, ["input1", "input2"])
>>> qds2 = QuantileDiscretizer(relativeError=0.01, handleInvalid="error", numBuckets=2,
...                            inputCols=["input1", "input2"], outputCols=["output1", "output2"])
>>> qds2.getRelativeError()
0.01
>>> qds2.setHandleInvalid("keep").fit(df2).transform(df2).show()
+------+------+-------+-------+
|input1|input2|output1|output2|
+------+------+-------+-------+
|   0.1|   0.0|    0.0|    0.0|
|   0.4|   1.0|    1.0|    1.0|
|   1.2|   1.3|    1.0|    1.0|
|   1.5|   1.5|    1.0|    1.0|
|   NaN|   NaN|    2.0|    2.0|
|   NaN|   NaN|    2.0|    2.0|
+------+------+-------+-------+
...
>>> qds3 = QuantileDiscretizer(relativeError=0.01, handleInvalid="error",
...                            numBucketsArray=[5, 10], inputCols=["input1", "input2"],
...                            outputCols=["output1", "output2"])
>>> qds3.setHandleInvalid("skip").fit(df2).transform(df2).show()
+------+------+-------+-------+
|input1|input2|output1|output2|
+------+------+-------+-------+
|   0.1|   0.0|    1.0|    1.0|
|   0.4|   1.0|    2.0|    2.0|
|   1.2|   1.3|    3.0|    3.0|
|   1.5|   1.5|    4.0|    4.0|
+------+------+-------+-------+
...
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getHandleInvalid(): Gets the value of handleInvalid or its default value.
getInputCol(): Gets the value of inputCol or its default value.
getInputCols(): Gets the value of inputCols or its default value.
getNumBuckets(): Gets the value of numBuckets or its default value.
getNumBucketsArray(): Gets the value of numBucketsArray or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getOutputCols(): Gets the value of outputCols or its default value.
getParam(paramName): Gets a param by its name.
getRelativeError(): Gets the value of relativeError or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setHandleInvalid(value): Sets the value of handleInvalid.
setInputCol(value): Sets the value of inputCol.
setInputCols(value): Sets the value of inputCols.
setNumBuckets(value): Sets the value of numBuckets.
setNumBucketsArray(value): Sets the value of numBucketsArray.
setOutputCol(value): Sets the value of outputCol.
setOutputCols(value): Sets the value of outputCols.
setParams(self, *[, numBuckets, inputCol, ...]): Set the params for the QuantileDiscretizer.
setRelativeError(value): Sets the value of relativeError.
write(): Returns an MLWriter instance for this ML instance.
Attributes
handleInvalid
inputCol
inputCols
numBuckets
numBucketsArray
outputCol
outputCols
params: Returns all params ordered by name.
relativeError
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance
explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.
explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
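A hedged sketch of consuming that iterator directly (train_df and a configured qds estimator are assumed names; normally fit(dataset, [paramMap, ...]) drives this for you). Because models may arrive out of order, the returned index is used to slot each one back into param-map order.

param_maps = [{qds.numBuckets: 2}, {qds.numBuckets: 4}]
models = [None] * len(param_maps)
for index, model in qds.fitMultiple(train_df, param_maps):
    models[index] = model   # index restores the original param-map order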
getHandleInvalid()
Gets the value of handleInvalid or its default value.

getInputCol()
Gets the value of inputCol or its default value.
getInputCols()
Gets the value of inputCols or its default value.

getNumBuckets()
Gets the value of numBuckets or its default value.

New in version 2.0.0.

getNumBucketsArray()
Gets the value of numBucketsArray or its default value.

New in version 3.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getOutputCols()
Gets the value of outputCols or its default value.

getParam(paramName)
Gets a param by its name.

getRelativeError()
Gets the value of relativeError or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setHandleInvalid(value)
Sets the value of handleInvalid.

setInputCol(value)
Sets the value of inputCol.

setInputCols(value)
Sets the value of inputCols.

New in version 3.0.0.
setNumBuckets(value)
Sets the value of numBuckets.

New in version 2.0.0.

setNumBucketsArray(value)
Sets the value of numBucketsArray.

New in version 3.0.0.

setOutputCol(value)
Sets the value of outputCol.

setOutputCols(value)
Sets the value of outputCols.

New in version 3.0.0.

setParams(self, \*, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None)
Set the params for the QuantileDiscretizer.

New in version 2.0.0.

setRelativeError(value)
Sets the value of relativeError.

New in version 2.0.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
handleInvalid = Param(parent='undefined', name='handleInvalid', doc="how to handle invalid entries. Options are skip (filter out rows with invalid values), error (throw an error), or keep (keep invalid values in a special additional bucket). Note that in the multiple columns case, the invalid handling is applied to all columns. That said for 'error' it will throw an error if any invalids are found in any columns, for 'skip' it will skip rows with any invalids in any columns, etc.")
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')
numBuckets = Param(parent='undefined', name='numBuckets', doc='Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must be >= 2.')
numBucketsArray = Param(parent='undefined', name='numBucketsArray', doc='Array of number of buckets (quantiles, or categories) into which data points are grouped. This is for multiple columns input. If transforming multiple columns and numBucketsArray is not set, but numBuckets is set, then numBuckets will be applied across all columns.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
relativeError = Param(parent='undefined', name='relativeError', doc='the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]')
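Tying the multiple-columns params above together, a minimal hedged sketch (df and its column names are illustrative assumptions) that buckets two columns with different quantile counts and routes invalid values to an extra bucket:

from pyspark.ml.feature import QuantileDiscretizer

qds = QuantileDiscretizer(
    inputCols=["height", "weight"],          # illustrative input columns
    outputCols=["height_bin", "weight_bin"],
    numBucketsArray=[3, 5],                  # per-column bucket counts
    handleInvalid="keep")                    # NaN/invalid rows go to a special extra bucket
binned = qds.fit(df).transform(df)           # df: any DataFrame with those numeric columns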
RobustScaler
class pyspark.ml.feature.RobustScaler(*args, **kwargs)
RobustScaler removes the median and scales the data according to the quantile range. The quantile range is by default IQR (Interquartile Range, quantile range between the 1st quartile = 25th quantile and the 3rd quartile = 75th quantile) but can be configured. Centering and scaling happen independently on each feature by computing the relevant statistics on the samples in the training set. Median and quantile range are then stored to be used on later data using the transform method. Note that NaN values are ignored in the computation of medians and ranges.
New in version 3.0.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> data = [(0, Vectors.dense([0.0, 0.0]),),
...         (1, Vectors.dense([1.0, -1.0]),),
...         (2, Vectors.dense([2.0, -2.0]),),
...         (3, Vectors.dense([3.0, -3.0]),),
...         (4, Vectors.dense([4.0, -4.0]),),]
>>> df = spark.createDataFrame(data, ["id", "features"])
>>> scaler = RobustScaler()
>>> scaler.setInputCol("features")
RobustScaler...
>>> scaler.setOutputCol("scaled")
RobustScaler...
>>> model = scaler.fit(df)
>>> model.setOutputCol("output")
RobustScalerModel...
>>> model.median
DenseVector([2.0, -2.0])
>>> model.range
DenseVector([2.0, 2.0])
>>> model.transform(df).collect()[1].output
DenseVector([0.5, -0.5])
>>> scalerPath = temp_path + "/robust-scaler"
>>> scaler.save(scalerPath)
>>> loadedScaler = RobustScaler.load(scalerPath)
>>> loadedScaler.getWithCentering() == scaler.getWithCentering()
True
>>> loadedScaler.getWithScaling() == scaler.getWithScaling()
True
>>> modelPath = temp_path + "/robust-scaler-model"
>>> model.save(modelPath)
>>> loadedModel = RobustScalerModel.load(modelPath)
>>> loadedModel.median == model.median
True
>>> loadedModel.range == model.range
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getInputCol() Gets the value of inputCol or its default value.
getLower() Gets the value of lower or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
getRelativeError() Gets the value of relativeError or its default value.
getUpper() Gets the value of upper or its default value.
getWithCentering() Gets the value of withCentering or its default value.
getWithScaling() Gets the value of withScaling or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setLower(value) Sets the value of lower.
setOutputCol(value) Sets the value of outputCol.
setParams(self, \*[, lower, upper, ...]) Sets params for this RobustScaler.
setRelativeError(value) Sets the value of relativeError.
setUpper(value) Sets the value of upper.
setWithCentering(value) Sets the value of withCentering.
setWithScaling(value) Sets the value of withScaling.
write() Returns an MLWriter instance for this ML instance.
Attributes
inputCol
lower
outputCol
params Returns all params ordered by name.
relativeError
upper
withCentering
withScaling
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getInputCol()
Gets the value of inputCol or its default value.

getLower()
Gets the value of lower or its default value.

New in version 3.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getRelativeError()
Gets the value of relativeError or its default value.

getUpper()
Gets the value of upper or its default value.

New in version 3.0.0.

getWithCentering()
Gets the value of withCentering or its default value.

New in version 3.0.0.

getWithScaling()
Gets the value of withScaling or its default value.

New in version 3.0.0.

hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.

New in version 3.0.0.

setLower(value)
Sets the value of lower.

New in version 3.0.0.

setOutputCol(value)
Sets the value of outputCol.

New in version 3.0.0.

setParams(self, \*, lower=0.25, upper=0.75, withCentering=False, withScaling=True, inputCol=None, outputCol=None, relativeError=0.001)
Sets params for this RobustScaler.

New in version 3.0.0.

setRelativeError(value)
Sets the value of relativeError.

New in version 3.0.0.

setUpper(value)
Sets the value of upper.

New in version 3.0.0.

setWithCentering(value)
Sets the value of withCentering.

New in version 3.0.0.

setWithScaling(value)
Sets the value of withScaling.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
lower = Param(parent='undefined', name='lower', doc='Lower quantile to calculate quantile range')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
relativeError = Param(parent='undefined', name='relativeError', doc='the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]')
upper = Param(parent='undefined', name='upper', doc='Upper quantile to calculate quantile range')
withCentering = Param(parent='undefined', name='withCentering', doc='Whether to center data with median')
withScaling = Param(parent='undefined', name='withScaling', doc='Whether to scale the data to quantile range')
RobustScalerModel
class pyspark.ml.feature.RobustScalerModel(java_model=None)
Model fitted by RobustScaler.
New in version 3.0.0.
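A minimal hedged sketch of producing and using the model (train_df with a vector column "features" is an assumed name; the RobustScaler example above is the authoritative doctest):

from pyspark.ml.feature import RobustScaler

scaler = RobustScaler(inputCol="features", outputCol="scaled")
model = scaler.fit(train_df)            # learns per-feature median and quantile range
print(model.median, model.range)        # statistics computed on the training set
scaled_df = model.transform(train_df)   # appends the "scaled" vector column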
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol() Gets the value of inputCol or its default value.
getLower() Gets the value of lower or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
getRelativeError() Gets the value of relativeError or its default value.
getUpper() Gets the value of upper or its default value.
getWithCentering() Gets the value of withCentering or its default value.
getWithScaling() Gets the value of withScaling or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setOutputCol(value) Sets the value of outputCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
inputCol
lower
median Median of the RobustScalerModel.
outputCol
params Returns all params ordered by name.
range Quantile range of the RobustScalerModel.
relativeError
upper
withCentering
withScaling
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getInputCol()
Gets the value of inputCol or its default value.

getLower()
Gets the value of lower or its default value.

New in version 3.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getRelativeError()
Gets the value of relativeError or its default value.

getUpper()
Gets the value of upper or its default value.

New in version 3.0.0.

getWithCentering()
Gets the value of withCentering or its default value.

New in version 3.0.0.

getWithScaling()
Gets the value of withScaling or its default value.

New in version 3.0.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.
classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.

New in version 3.0.0.

setOutputCol(value)
Sets the value of outputCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
lower = Param(parent='undefined', name='lower', doc='Lower quantile to calculate quantile range')
median
Median of the RobustScalerModel.
New in version 3.0.0.
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
range
Quantile range of the RobustScalerModel.
New in version 3.0.0.
relativeError = Param(parent='undefined', name='relativeError', doc='the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]')
upper = Param(parent='undefined', name='upper', doc='Upper quantile to calculate quantile range')
withCentering = Param(parent='undefined', name='withCentering', doc='Whether to center data with median')
withScaling = Param(parent='undefined', name='withScaling', doc='Whether to scale the data to quantile range')
RegexTokenizer
class pyspark.ml.feature.RegexTokenizer(*args, **kwargs)
A regex based tokenizer that extracts tokens either by using the provided regex pattern (in Java dialect) to split the text (default) or repeatedly matching the regex (if gaps is false). Optional parameters also allow filtering tokens using a minimal length. It returns an array of strings that can be empty.
New in version 1.4.0.
Examples
>>> df = spark.createDataFrame([("A B c",)], ["text"])>>> reTokenizer = RegexTokenizer()>>> reTokenizer.setInputCol("text")RegexTokenizer...>>> reTokenizer.setOutputCol("words")RegexTokenizer...>>> reTokenizer.transform(df).head()Row(text='A B c', words=['a', 'b', 'c'])>>> # Change a parameter.>>> reTokenizer.setParams(outputCol="tokens").transform(df).head()Row(text='A B c', tokens=['a', 'b', 'c'])>>> # Temporarily modify a parameter.>>> reTokenizer.transform(df, {reTokenizer.outputCol: "words"}).head()Row(text='A B c', words=['a', 'b', 'c'])>>> reTokenizer.transform(df).head()Row(text='A B c', tokens=['a', 'b', 'c'])>>> # Must use keyword arguments to specify params.>>> reTokenizer.setParams("text")Traceback (most recent call last):
...TypeError: Method setParams forces keyword arguments.>>> regexTokenizerPath = temp_path + "/regex-tokenizer">>> reTokenizer.save(regexTokenizerPath)>>> loadedReTokenizer = RegexTokenizer.load(regexTokenizerPath)>>> loadedReTokenizer.getMinTokenLength() == reTokenizer.getMinTokenLength()True>>> loadedReTokenizer.getGaps() == reTokenizer.getGaps()True>>> loadedReTokenizer.transform(df).take(1) == reTokenizer.transform(df).take(1)True
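Since gaps flips the regex between split mode and match mode, a short hedged companion sketch: with gaps=False the pattern describes the tokens themselves rather than the separators (spark and the one-row DataFrame mirror the example above).

>>> matchTokenizer = RegexTokenizer(inputCol="text", outputCol="words", gaps=False, pattern="\\w+")
>>> matchTokenizer.transform(spark.createDataFrame([("A B c",)], ["text"])).head()
Row(text='A B c', words=['a', 'b', 'c'])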
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getGaps() Gets the value of gaps or its default value.
getInputCol() Gets the value of inputCol or its default value.
getMinTokenLength() Gets the value of minTokenLength or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
getPattern() Gets the value of pattern or its default value.
getToLowercase() Gets the value of toLowercase or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setGaps(value) Sets the value of gaps.
setInputCol(value) Sets the value of inputCol.
setMinTokenLength(value) Sets the value of minTokenLength.
setOutputCol(value) Sets the value of outputCol.
setParams(self, \*[, minTokenLength, gaps, ...]) Sets params for this RegexTokenizer.
setPattern(value) Sets the value of pattern.
setToLowercase(value) Sets the value of toLowercase.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
gaps
inputCol
minTokenLength
outputCol
params Returns all params ordered by name.
pattern
toLowercase
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getGaps()
Gets the value of gaps or its default value.

New in version 1.4.0.

getInputCol()
Gets the value of inputCol or its default value.

getMinTokenLength()
Gets the value of minTokenLength or its default value.

New in version 1.4.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getPattern()
Gets the value of pattern or its default value.

New in version 1.4.0.

getToLowercase()
Gets the value of toLowercase or its default value.

New in version 2.0.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setGaps(value)
Sets the value of gaps.

New in version 1.4.0.

setInputCol(value)
Sets the value of inputCol.

setMinTokenLength(value)
Sets the value of minTokenLength.

New in version 1.4.0.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, \*, minTokenLength=1, gaps=True, pattern="\s+", inputCol=None, outputCol=None, toLowercase=True)
Sets params for this RegexTokenizer.

New in version 1.4.0.
setPattern(value)
Sets the value of pattern.

New in version 1.4.0.

setToLowercase(value)
Sets the value of toLowercase.

New in version 2.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
gaps = Param(parent='undefined', name='gaps', doc='whether regex splits on gaps (True) or matches tokens (False)')
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
minTokenLength = Param(parent='undefined', name='minTokenLength', doc='minimum token length (>= 0)')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
pattern = Param(parent='undefined', name='pattern', doc='regex pattern (Java dialect) used for tokenizing')
toLowercase = Param(parent='undefined', name='toLowercase', doc='whether to convert all characters to lowercase before tokenizing')
RFormula
class pyspark.ml.feature.RFormula(*args, **kwargs)
Implements the transforms required for fitting a dataset against an R model formula. Currently we support a limited subset of the R operators, including '~', '.', ':', '+', '-', '*', and '^'.
New in version 1.5.0.
Notes
Also see the R formula docs.
Examples
>>> df = spark.createDataFrame([
...     (1.0, 1.0, "a"),
...     (0.0, 2.0, "b"),
...     (0.0, 0.0, "a")
... ], ["y", "x", "s"])
>>> rf = RFormula(formula="y ~ x + s")
>>> model = rf.fit(df)
>>> model.getLabelCol()
'label'
>>> model.transform(df).show()
+---+---+---+---------+-----+
|  y|  x|  s| features|label|
+---+---+---+---------+-----+
|1.0|1.0|  a|[1.0,1.0]|  1.0|
|0.0|2.0|  b|[2.0,0.0]|  0.0|
|0.0|0.0|  a|[0.0,1.0]|  0.0|
+---+---+---+---------+-----+
...
>>> rf.fit(df, {rf.formula: "y ~ . - s"}).transform(df).show()
+---+---+---+--------+-----+
|  y|  x|  s|features|label|
+---+---+---+--------+-----+
|1.0|1.0|  a|   [1.0]|  1.0|
|0.0|2.0|  b|   [2.0]|  0.0|
|0.0|0.0|  a|   [0.0]|  0.0|
+---+---+---+--------+-----+
...
>>> rFormulaPath = temp_path + "/rFormula"
>>> rf.save(rFormulaPath)
>>> loadedRF = RFormula.load(rFormulaPath)
>>> loadedRF.getFormula() == rf.getFormula()
True
>>> loadedRF.getFeaturesCol() == rf.getFeaturesCol()
True
>>> loadedRF.getLabelCol() == rf.getLabelCol()
True
>>> loadedRF.getHandleInvalid() == rf.getHandleInvalid()
True
>>> str(loadedRF)
'RFormula(y ~ x + s) (uid=...)'
>>> modelPath = temp_path + "/rFormulaModel"
>>> model.save(modelPath)
>>> loadedModel = RFormulaModel.load(modelPath)
>>> loadedModel.uid == model.uid
True
>>> loadedModel.transform(df).show()
+---+---+---+---------+-----+
|  y|  x|  s| features|label|
+---+---+---+---------+-----+
|1.0|1.0|  a|[1.0,1.0]|  1.0|
|0.0|2.0|  b|[2.0,0.0]|  0.0|
|0.0|0.0|  a|[0.0,1.0]|  0.0|
+---+---+---+---------+-----+
...
>>> str(loadedModel)
'RFormulaModel(ResolvedRFormula(label=y, terms=[x,s], hasIntercept=true)) (uid=...)'
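As a further hedged sketch (reusing the df defined above), the ':' operator listed in the class description expresses an interaction between terms; the exact feature layout depends on how RFormula encodes the string column s:

# Hedged sketch: interaction term x:s added alongside the main effect x.
rf_inter = RFormula(formula="y ~ x + x:s")
inter_model = rf_inter.fit(df)       # df as in the example above
inter_model.transform(df).show()     # features now include x and the x:s interaction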
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getFeaturesCol() Gets the value of featuresCol or its default value.
getForceIndexLabel() Gets the value of forceIndexLabel.
getFormula() Gets the value of formula.
getHandleInvalid() Gets the value of handleInvalid or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getStringIndexerOrderType() Gets the value of stringIndexerOrderType or its default value 'frequencyDesc'.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setForceIndexLabel(value) Sets the value of forceIndexLabel.
setFormula(value) Sets the value of formula.
setHandleInvalid(value) Sets the value of handleInvalid.
setLabelCol(value) Sets the value of labelCol.
setParams(self, \*[, formula, featuresCol, ...]) Sets params for RFormula.
setStringIndexerOrderType(value) Sets the value of stringIndexerOrderType.
write() Returns an MLWriter instance for this ML instance.
Attributes
featuresCol
forceIndexLabel
formula
handleInvalid
labelCol
params Returns all params ordered by name.
stringIndexerOrderType
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getForceIndexLabel()
Gets the value of forceIndexLabel.

New in version 2.1.0.

getFormula()
Gets the value of formula.

New in version 1.5.0.

getHandleInvalid()
Gets the value of handleInvalid or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getStringIndexerOrderType()
Gets the value of stringIndexerOrderType or its default value 'frequencyDesc'.

New in version 2.3.0.
hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

setForceIndexLabel(value)
Sets the value of forceIndexLabel.

New in version 2.1.0.

setFormula(value)
Sets the value of formula.

New in version 1.5.0.

setHandleInvalid(value)
Sets the value of handleInvalid.

setLabelCol(value)
Sets the value of labelCol.

setParams(self, \*, formula=None, featuresCol="features", labelCol="label", forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", handleInvalid="error")
Sets params for RFormula.

New in version 1.5.0.

setStringIndexerOrderType(value)
Sets the value of stringIndexerOrderType.

New in version 2.3.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
forceIndexLabel = Param(parent='undefined', name='forceIndexLabel', doc='Force to index label whether it is numeric or string')
formula = Param(parent='undefined', name='formula', doc='R model formula')
handleInvalid = Param(parent='undefined', name='handleInvalid', doc="how to handle invalid entries. Options are 'skip' (filter out rows with invalid values), 'error' (throw an error), or 'keep' (put invalid data in a special additional bucket, at index numLabels).")
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
stringIndexerOrderType = Param(parent='undefined', name='stringIndexerOrderType', doc='How to order categories of a string feature column used by StringIndexer. The last category after ordering is dropped when encoding strings. Supported options: frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. The default value is frequencyDesc. When the ordering is set to alphabetDesc, RFormula drops the same category as R when encoding strings.')
RFormulaModel
class pyspark.ml.feature.RFormulaModel(java_model=None)
Model fitted by RFormula. Fitting is required to determine the factor levels of formula terms.
New in version 1.5.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFeaturesCol() Gets the value of featuresCol or its default value.
getForceIndexLabel() Gets the value of forceIndexLabel.
getFormula() Gets the value of formula.
getHandleInvalid() Gets the value of handleInvalid or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getStringIndexerOrderType() Gets the value of stringIndexerOrderType or its default value 'frequencyDesc'.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
featuresCol
forceIndexLabel
formula
handleInvalid
labelCol
params Returns all params ordered by name.
stringIndexerOrderType
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getFeaturesCol()
Gets the value of featuresCol or its default value.

getForceIndexLabel()
Gets the value of forceIndexLabel.

New in version 2.1.0.

getFormula()
Gets the value of formula.

New in version 1.5.0.

getHandleInvalid()
Gets the value of handleInvalid or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getStringIndexerOrderType()
Gets the value of stringIndexerOrderType or its default value 'frequencyDesc'.

New in version 2.3.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
forceIndexLabel = Param(parent='undefined', name='forceIndexLabel', doc='Force to index label whether it is numeric or string')
formula = Param(parent='undefined', name='formula', doc='R model formula')
handleInvalid = Param(parent='undefined', name='handleInvalid', doc="how to handle invalid entries. Options are 'skip' (filter out rows with invalid values), 'error' (throw an error), or 'keep' (put invalid data in a special additional bucket, at index numLabels).")
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
stringIndexerOrderType = Param(parent='undefined', name='stringIndexerOrderType', doc='How to order categories of a string feature column used by StringIndexer. The last category after ordering is dropped when encoding strings. Supported options: frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. The default value is frequencyDesc. When the ordering is set to alphabetDesc, RFormula drops the same category as R when encoding strings.')
SQLTransformer
class pyspark.ml.feature.SQLTransformer(*args, **kwargs)
Implements the transforms which are defined by a SQL statement. Currently we only support SQL syntax like SELECT ... FROM __THIS__, where __THIS__ represents the underlying table of the input dataset.
New in version 1.6.0.
Examples
>>> df = spark.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"])
>>> sqlTrans = SQLTransformer(
...     statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
>>> sqlTrans.transform(df).head()
Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0)
>>> sqlTransformerPath = temp_path + "/sql-transformer"
>>> sqlTrans.save(sqlTransformerPath)
>>> loadedSqlTrans = SQLTransformer.load(sqlTransformerPath)
>>> loadedSqlTrans.getStatement() == sqlTrans.getStatement()
True
>>> loadedSqlTrans.transform(df).take(1) == sqlTrans.transform(df).take(1)
True
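Because the statement is ordinary SELECT syntax over __THIS__, row filtering also works; a brief hedged sketch reusing the df above:

>>> filterTrans = SQLTransformer(statement="SELECT * FROM __THIS__ WHERE v1 > 1.0")
>>> filterTrans.transform(df).count()
1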
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getStatement() Gets the value of statement or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setParams(self, \*[, statement]) Sets params for this SQLTransformer.
setStatement(value) Sets the value of statement.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
params Returns all params ordered by name.
statement
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getStatement()
Gets the value of statement or its default value.
New in version 1.6.0.
hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setParams(self, \*, statement=None)
Sets params for this SQLTransformer.

New in version 1.6.0.

setStatement(value)
Sets the value of statement.

New in version 1.6.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
statement = Param(parent='undefined', name='statement', doc='SQL statement')
StandardScaler
class pyspark.ml.feature.StandardScaler(*args, **kwargs)
Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set.

The "unit std" is computed using the corrected sample standard deviation, which is computed as the square root of the unbiased sample variance.
New in version 1.4.0.
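To make that definition concrete, a hedged numeric check matching the two-sample example below (values 0.0 and 2.0): the unbiased sample variance is ((0 - 1)^2 + (2 - 1)^2) / (2 - 1) = 2, so the unit std is sqrt(2) ≈ 1.4142.

>>> import math
>>> samples = [0.0, 2.0]
>>> mean = sum(samples) / len(samples)                                  # 1.0
>>> var = sum((x - mean) ** 2 for x in samples) / (len(samples) - 1)    # unbiased: 2.0
>>> round(math.sqrt(var), 4)
1.4142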
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"])
>>> standardScaler = StandardScaler()
>>> standardScaler.setInputCol("a")
StandardScaler...
>>> standardScaler.setOutputCol("scaled")
StandardScaler...
>>> model = standardScaler.fit(df)
>>> model.getInputCol()
'a'
>>> model.setOutputCol("output")
StandardScalerModel...
>>> model.mean
DenseVector([1.0])
>>> model.std
DenseVector([1.4142])
>>> model.transform(df).collect()[1].output
DenseVector([1.4142])
>>> standardScalerPath = temp_path + "/standard-scaler"
>>> standardScaler.save(standardScalerPath)
>>> loadedStandardScaler = StandardScaler.load(standardScalerPath)
>>> loadedStandardScaler.getWithMean() == standardScaler.getWithMean()
True
>>> loadedStandardScaler.getWithStd() == standardScaler.getWithStd()
True
>>> modelPath = temp_path + "/standard-scaler-model"
>>> model.save(modelPath)
>>> loadedModel = StandardScalerModel.load(modelPath)
>>> loadedModel.std == model.std
True
>>> loadedModel.mean == model.mean
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getInputCol() Gets the value of inputCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
getWithMean() Gets the value of withMean or its default value.
getWithStd() Gets the value of withStd or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setOutputCol(value) Sets the value of outputCol.
setParams(self, \*[, withMean, withStd, ...]) Sets params for this StandardScaler.
setWithMean(value) Sets the value of withMean.
setWithStd(value) Sets the value of withStd.
write() Returns an MLWriter instance for this ML instance.
Attributes
inputCol
outputCol
params Returns all params ordered by name.
withMean
withStd
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
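A sketch of the list-of-param-maps form (editorial; df stands for any DataFrame with a vector column named "features"):

>>> scaler = StandardScaler(inputCol="features", outputCol="scaled")
>>> maps = [{scaler.withMean: True}, {scaler.withMean: False}]
>>> models = scaler.fit(df, maps)  # one fitted model per param map
>>> len(models)
2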
fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable that contains one model for each param map. Each call to next(modelIterator) returns (index, model), where the model was fit using paramMaps[index]. Index values may not be sequential.
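A sketch of consuming the iterator (editorial; df as in the fit example above). Because index values may arrive out of order, results are stored by index rather than appended:

>>> scaler = StandardScaler(inputCol="features", outputCol="scaled")
>>> maps = [{scaler.withMean: True}, {scaler.withMean: False}]
>>> models = [None] * len(maps)
>>> for index, model in scaler.fitMultiple(df, maps):
...     models[index] = model  # index order is not guaranteed
>>> sum(m is not None for m in models)
2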
getInputCol()
    Gets the value of inputCol or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

getWithMean()
    Gets the value of withMean or its default value.
New in version 1.4.0.
getWithStd()
    Gets the value of withStd or its default value.
New in version 1.4.0.
hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

setOutputCol(value)
    Sets the value of outputCol.

setParams(self, *, withMean=False, withStd=True, inputCol=None, outputCol=None)
    Sets params for this StandardScaler.
New in version 1.4.0.
setWithMean(value)
    Sets the value of withMean.

New in version 1.4.0.

setWithStd(value)
    Sets the value of withStd.

New in version 1.4.0.

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
withMean = Param(parent='undefined', name='withMean', doc='Center data with mean')
withStd = Param(parent='undefined', name='withStd', doc='Scale to unit standard deviation')
StandardScalerModel
class pyspark.ml.feature.StandardScalerModel(java_model=None)
    Model fitted by StandardScaler.
New in version 1.4.0.
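The original page carries no example for this class; a minimal editorial sketch (assumes a SparkSession named spark, as in the surrounding examples):

>>> from pyspark.ml.feature import StandardScaler
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["features"])
>>> model = StandardScaler(inputCol="features", outputCol="scaled").fit(df)
>>> model.mean
DenseVector([1.0])
>>> model.std
DenseVector([1.4142])
>>> model.transform(df).head().scaled
DenseVector([0.0])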
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
getWithMean(): Gets the value of withMean or its default value.
getWithStd(): Gets the value of withStd or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
inputCol
mean: Mean of the StandardScalerModel.
outputCol
params: Returns all params ordered by name.
std: Standard deviation of the StandardScalerModel.
withMean
withStd
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getInputCol()
    Gets the value of inputCol or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

getWithMean()
    Gets the value of withMean or its default value.
New in version 1.4.0.
getWithStd()
    Gets the value of withStd or its default value.
New in version 1.4.0.
hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.
setInputCol(value)
    Sets the value of inputCol.

setOutputCol(value)
    Sets the value of outputCol.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
mean
    Mean of the StandardScalerModel.
New in version 2.0.0.
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

std
    Standard deviation of the StandardScalerModel.
New in version 2.0.0.
withMean = Param(parent='undefined', name='withMean', doc='Center data with mean')
withStd = Param(parent='undefined', name='withStd', doc='Scale to unit standard deviation')
StopWordsRemover
class pyspark.ml.feature.StopWordsRemover(*args, **kwargs)
    A feature transformer that filters out stop words from input. Since 3.0.0, StopWordsRemover can filter out multiple columns at once by setting the inputCols parameter. Note that when both the inputCol and inputCols parameters are set, an Exception will be thrown.
New in version 1.6.0.
Notes
Null values in the input array are preserved unless null is explicitly added to stopWords.
Examples
>>> df = spark.createDataFrame([(["a", "b", "c"],)], ["text"])
>>> remover = StopWordsRemover(stopWords=["b"])
>>> remover.setInputCol("text")
StopWordsRemover...
>>> remover.setOutputCol("words")
StopWordsRemover...
>>> remover.transform(df).head().words == ['a', 'c']
True
>>> stopWordsRemoverPath = temp_path + "/stopwords-remover"
>>> remover.save(stopWordsRemoverPath)
>>> loadedRemover = StopWordsRemover.load(stopWordsRemoverPath)
>>> loadedRemover.getStopWords() == remover.getStopWords()
True
>>> loadedRemover.getCaseSensitive() == remover.getCaseSensitive()
True
>>> loadedRemover.transform(df).take(1) == remover.transform(df).take(1)
True
>>> df2 = spark.createDataFrame([(["a", "b", "c"], ["a", "b"])], ["text1", "text2"])
>>> remover2 = StopWordsRemover(stopWords=["b"])
>>> remover2.setInputCols(["text1", "text2"]).setOutputCols(["words1", "words2"])
StopWordsRemover...
>>> remover2.transform(df2).show()
+---------+------+------+------+
|    text1| text2|words1|words2|
+---------+------+------+------+
|[a, b, c]|[a, b]|[a, c]|   [a]|
+---------+------+------+------+
...
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCaseSensitive(): Gets the value of caseSensitive or its default value.
getInputCol(): Gets the value of inputCol or its default value.
getInputCols(): Gets the value of inputCols or its default value.
getLocale(): Gets the value of locale.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getOutputCols(): Gets the value of outputCols or its default value.
getParam(paramName): Gets a param by its name.
getStopWords(): Gets the value of stopWords or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
loadDefaultStopWords(language): Loads the default stop words for the given language.
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setCaseSensitive(value): Sets the value of caseSensitive.
setInputCol(value): Sets the value of inputCol.
setInputCols(value): Sets the value of inputCols.
setLocale(value): Sets the value of locale.
setOutputCol(value): Sets the value of outputCol.
setOutputCols(value): Sets the value of outputCols.
setParams(self, *[, inputCol, outputCol, ...]): Sets params for this StopWordsRemover.
setStopWords(value): Sets the value of stopWords.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
caseSensitive
inputCol
inputCols
locale
outputCol
outputCols
params: Returns all params ordered by name.
stopWords
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getCaseSensitive()
    Gets the value of caseSensitive or its default value.
New in version 1.6.0.
getInputCol()
    Gets the value of inputCol or its default value.

getInputCols()
    Gets the value of inputCols or its default value.
getLocale()
    Gets the value of locale.
New in version 2.4.0.
getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getOutputCols()
    Gets the value of outputCols or its default value.

getParam(paramName)
    Gets a param by its name.

getStopWords()
    Gets the value of stopWords or its default value.
New in version 1.6.0.
hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

static loadDefaultStopWords(language)
    Loads the default stop words for the given language. Supported languages: danish, dutch, english, finnish, french, german, hungarian, italian, norwegian, portuguese, russian, spanish, swedish, turkish.
New in version 2.0.0.
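For example (an editorial sketch, not from the original reference):

>>> english = StopWordsRemover.loadDefaultStopWords("english")
>>> "the" in english
True
>>> remover = StopWordsRemover(inputCol="text", outputCol="words", stopWords=english)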
classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setCaseSensitive(value)
    Sets the value of caseSensitive.

New in version 1.6.0.

setInputCol(value)
    Sets the value of inputCol.

setInputCols(value)
    Sets the value of inputCols.
New in version 3.0.0.
setLocale(value)
    Sets the value of locale.

New in version 2.4.0.

setOutputCol(value)
    Sets the value of outputCol.

setOutputCols(value)
    Sets the value of outputCols.
New in version 3.0.0.
setParams(self, *, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, locale=None, inputCols=None, outputCols=None)
    Sets params for this StopWordsRemover.
New in version 1.6.0.
setStopWords(value)
    Sets the value of stopWords.

New in version 1.6.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
caseSensitive = Param(parent='undefined', name='caseSensitive', doc='whether to do a case sensitive comparison over the stop words')
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')
locale = Param(parent='undefined', name='locale', doc='locale of the input. ignored when case sensitive is true')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
stopWords = Param(parent='undefined', name='stopWords', doc='The words to be filtered out')
StringIndexer
class pyspark.ml.feature.StringIndexer(*args, **kwargs)
    A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. The indices are in [0, numLabels). By default, this is ordered by label frequencies so the most frequent label gets index 0. The ordering behavior is controlled by setting stringOrderType. Its default value is 'frequencyDesc'.
New in version 1.4.0.
Examples
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed",
...     stringOrderType="frequencyDesc")
>>> stringIndexer.setHandleInvalid("error")
StringIndexer...
>>> model = stringIndexer.fit(stringIndDf)
>>> model.setHandleInvalid("error")
StringIndexerModel...
>>> td = model.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
...     key=lambda x: x[0])
[(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
>>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels)
>>> itd = inverter.transform(td)
>>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
...     key=lambda x: x[0])
[(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
>>> stringIndexerPath = temp_path + "/string-indexer"
>>> stringIndexer.save(stringIndexerPath)
>>> loadedIndexer = StringIndexer.load(stringIndexerPath)
>>> loadedIndexer.getHandleInvalid() == stringIndexer.getHandleInvalid()
True
>>> modelPath = temp_path + "/string-indexer-model"
>>> model.save(modelPath)
>>> loadedModel = StringIndexerModel.load(modelPath)
>>> loadedModel.labels == model.labels
True
>>> indexToStringPath = temp_path + "/index-to-string"
>>> inverter.save(indexToStringPath)
>>> loadedInverter = IndexToString.load(indexToStringPath)
>>> loadedInverter.getLabels() == inverter.getLabels()
True
>>> loadedModel.transform(stringIndDf).take(1) == model.transform(stringIndDf).take(1)
True
>>> stringIndexer.getStringOrderType()
'frequencyDesc'
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error",
...     stringOrderType="alphabetDesc")
>>> model = stringIndexer.fit(stringIndDf)
>>> td = model.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
...     key=lambda x: x[0])
[(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)]
>>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"],
...     inputCol="label", outputCol="indexed", handleInvalid="error")
>>> result = fromlabelsModel.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]),
...     key=lambda x: x[0])
[(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)]
>>> testData = sc.parallelize([Row(id=0, label1="a", label2="e"),
...     Row(id=1, label1="b", label2="f"),
...     Row(id=2, label1="c", label2="e"),
...     Row(id=3, label1="a", label2="f"),
...     Row(id=4, label1="a", label2="f"),
...     Row(id=5, label1="c", label2="f")], 3)
>>> multiRowDf = spark.createDataFrame(testData)
>>> inputs = ["label1", "label2"]
>>> outputs = ["index1", "index2"]
>>> stringIndexer = StringIndexer(inputCols=inputs, outputCols=outputs)
>>> model = stringIndexer.fit(multiRowDf)
>>> result = model.transform(multiRowDf)
>>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
...     result.index2).collect()]), key=lambda x: x[0])
[(0, 0.0, 1.0), (1, 2.0, 0.0), (2, 1.0, 1.0), (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)]
>>> fromlabelsModel = StringIndexerModel.from_arrays_of_labels([["a", "b", "c"], ["e", "f"]],
...     inputCols=inputs, outputCols=outputs)
>>> result = fromlabelsModel.transform(multiRowDf)
>>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
...     result.index2).collect()]), key=lambda x: x[0])
[(0, 0.0, 0.0), (1, 1.0, 1.0), (2, 2.0, 0.0), (3, 0.0, 1.0), (4, 0.0, 1.0), (5, 2.0, 1.0)]
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getHandleInvalid(): Gets the value of handleInvalid or its default value.
getInputCol(): Gets the value of inputCol or its default value.
getInputCols(): Gets the value of inputCols or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getOutputCols(): Gets the value of outputCols or its default value.
getParam(paramName): Gets a param by its name.
getStringOrderType(): Gets the value of stringOrderType or its default value 'frequencyDesc'.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setHandleInvalid(value): Sets the value of handleInvalid.
setInputCol(value): Sets the value of inputCol.
setInputCols(value): Sets the value of inputCols.
setOutputCol(value): Sets the value of outputCol.
setOutputCols(value): Sets the value of outputCols.
setParams(self, *[, inputCol, outputCol, ...]): Sets params for this StringIndexer.
setStringOrderType(value): Sets the value of stringOrderType.
write(): Returns an MLWriter instance for this ML instance.
Attributes
handleInvalid
inputCol
inputCols
outputCol
outputCols
params: Returns all params ordered by name.
stringOrderType
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable that contains one model for each param map. Each call to next(modelIterator) returns (index, model), where the model was fit using paramMaps[index]. Index values may not be sequential.
getHandleInvalid()
    Gets the value of handleInvalid or its default value.

getInputCol()
    Gets the value of inputCol or its default value.

getInputCols()
    Gets the value of inputCols or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getOutputCols()
    Gets the value of outputCols or its default value.

getParam(paramName)
    Gets a param by its name.

getStringOrderType()
    Gets the value of stringOrderType or its default value 'frequencyDesc'.
New in version 2.3.0.
hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setHandleInvalid(value)
    Sets the value of handleInvalid.
setInputCol(value)
    Sets the value of inputCol.

setInputCols(value)
    Sets the value of inputCols.

New in version 3.0.0.

setOutputCol(value)
    Sets the value of outputCol.

setOutputCols(value)
    Sets the value of outputCols.
New in version 3.0.0.
setParams(self, *, inputCol=None, outputCol=None, inputCols=None, outputCols=None, handleInvalid="error", stringOrderType="frequencyDesc")
    Sets params for this StringIndexer.
New in version 1.4.0.
setStringOrderType(value)
    Sets the value of stringOrderType.
New in version 2.3.0.
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
handleInvalid = Param(parent='undefined', name='handleInvalid', doc="how to handle invalid data (unseen or NULL values) in features and label column of string type. Options are 'skip' (filter out rows with invalid data), error (throw an error), or 'keep' (put invalid data in a special additional bucket, at index numLabels).")
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
stringOrderType = Param(parent='undefined', name='stringOrderType', doc='How to order labels of string column. The first label after ordering is assigned an index of 0. Supported options: frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. Default is frequencyDesc. In case of equal frequency when under frequencyDesc/Asc, the strings are further sorted alphabetically')
StringIndexerModel
class pyspark.ml.feature.StringIndexerModel(java_model=None)
    Model fitted by StringIndexer.
New in version 1.4.0.
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
from_arrays_of_labels(arrayOfLabels, inputCols[, ...]): Construct the model directly from an array of arrays of label strings; requires an active SparkContext.
from_labels(labels, inputCol[, outputCol, ...]): Construct the model directly from an array of label strings; requires an active SparkContext.
getHandleInvalid(): Gets the value of handleInvalid or its default value.
getInputCol(): Gets the value of inputCol or its default value.
getInputCols(): Gets the value of inputCols or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getOutputCols(): Gets the value of outputCols or its default value.
getParam(paramName): Gets a param by its name.
getStringOrderType(): Gets the value of stringOrderType or its default value 'frequencyDesc'.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setHandleInvalid(value): Sets the value of handleInvalid.
setInputCol(value): Sets the value of inputCol.
setInputCols(value): Sets the value of inputCols.
setOutputCol(value): Sets the value of outputCol.
setOutputCols(value): Sets the value of outputCols.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
handleInvalid
inputCol
inputCols
labels: Ordered list of labels, corresponding to indices to be assigned.
labelsArray: Array of ordered lists of labels, corresponding to indices to be assigned for each input column.
outputCol
outputCols
params: Returns all params ordered by name.
stringOrderType
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
classmethod from_arrays_of_labels(arrayOfLabels, inputCols, outputCols=None, handleInvalid=None)
    Construct the model directly from an array of arrays of label strings; requires an active SparkContext.
New in version 3.0.0.
classmethod from_labels(labels, inputCol, outputCol=None, handleInvalid=None)
    Construct the model directly from an array of label strings; requires an active SparkContext.
New in version 2.4.0.
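A brief editorial sketch (not from the original reference; assumes a SparkSession named spark):

>>> model = StringIndexerModel.from_labels(["a", "b", "c"],
...     inputCol="label", outputCol="indexed", handleInvalid="error")
>>> model.transform(spark.createDataFrame([("b",)], ["label"])).head().indexed
1.0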
getHandleInvalid()
    Gets the value of handleInvalid or its default value.

getInputCol()
    Gets the value of inputCol or its default value.

getInputCols()
    Gets the value of inputCols or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getOutputCols()
    Gets the value of outputCols or its default value.

getParam(paramName)
    Gets a param by its name.
getStringOrderType()
    Gets the value of stringOrderType or its default value 'frequencyDesc'.
New in version 2.3.0.
hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setHandleInvalid(value)
    Sets the value of handleInvalid.
New in version 2.4.0.
setInputCol(value)
    Sets the value of inputCol.
setInputCols(value)
    Sets the value of inputCols.

New in version 3.0.0.

setOutputCol(value)
    Sets the value of outputCol.

setOutputCols(value)
    Sets the value of outputCols.

New in version 3.0.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
handleInvalid = Param(parent='undefined', name='handleInvalid', doc="how to handle invalid data (unseen or NULL values) in features and label column of string type. Options are 'skip' (filter out rows with invalid data), error (throw an error), or 'keep' (put invalid data in a special additional bucket, at index numLabels).")
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')
labels
    Ordered list of labels, corresponding to indices to be assigned.
Deprecated since version 3.1.0: It will be removed in future versions. Use labelsArray method instead.
New in version 1.5.0.
labelsArray
    Array of ordered lists of labels, corresponding to indices to be assigned for each input column.
New in version 3.0.2.
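For the multi-column constructor shown earlier, an editorial sketch of the shape this attribute takes:

>>> model = StringIndexerModel.from_arrays_of_labels([["a", "b", "c"], ["e", "f"]],
...     inputCols=["label1", "label2"], outputCols=["index1", "index2"])
>>> len(model.labelsArray)  # one ordered label list per input column
2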
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
stringOrderType = Param(parent='undefined', name='stringOrderType', doc='How to order labels of string column. The first label after ordering is assigned an index of 0. Supported options: frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. Default is frequencyDesc. In case of equal frequency when under frequencyDesc/Asc, the strings are further sorted alphabetically')
Tokenizer
class pyspark.ml.feature.Tokenizer(*args, **kwargs)
    A tokenizer that converts the input string to lowercase and then splits it by white spaces.
New in version 1.3.0.
Examples
>>> df = spark.createDataFrame([("a b c",)], ["text"])>>> tokenizer = Tokenizer(outputCol="words")>>> tokenizer.setInputCol("text")Tokenizer...>>> tokenizer.transform(df).head()Row(text='a b c', words=['a', 'b', 'c'])>>> # Change a parameter.>>> tokenizer.setParams(outputCol="tokens").transform(df).head()Row(text='a b c', tokens=['a', 'b', 'c'])>>> # Temporarily modify a parameter.>>> tokenizer.transform(df, {tokenizer.outputCol: "words"}).head()Row(text='a b c', words=['a', 'b', 'c'])>>> tokenizer.transform(df).head()Row(text='a b c', tokens=['a', 'b', 'c'])>>> # Must use keyword arguments to specify params.>>> tokenizer.setParams("text")Traceback (most recent call last):
...TypeError: Method setParams forces keyword arguments.>>> tokenizerPath = temp_path + "/tokenizer">>> tokenizer.save(tokenizerPath)>>> loadedTokenizer = Tokenizer.load(tokenizerPath)>>> loadedTokenizer.transform(df).head().tokens == tokenizer.transform(df).head().→˓tokensTrue
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
setParams(self, *[, inputCol, outputCol]): Sets params for this Tokenizer.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
inputCol
outputCol
params: Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getInputCol()
    Gets the value of inputCol or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

setOutputCol(value)
    Sets the value of outputCol.

setParams(self, *, inputCol=None, outputCol=None)
    Sets params for this Tokenizer.
New in version 1.3.0.
transform(dataset, params=None)
    Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
VarianceThresholdSelector
class pyspark.ml.feature.VarianceThresholdSelector(*args, **kwargs)
    Feature selector that removes all low-variance features. Features with a variance not greater than the threshold will be removed. The default is to keep all features with non-zero variance, i.e., remove the features that have the same value in all samples.
New in version 3.1.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame(
...     [(Vectors.dense([6.0, 7.0, 0.0, 7.0, 6.0, 0.0]),),
...      (Vectors.dense([0.0, 9.0, 6.0, 0.0, 5.0, 9.0]),),
...      (Vectors.dense([0.0, 9.0, 3.0, 0.0, 5.0, 5.0]),),
...      (Vectors.dense([0.0, 9.0, 8.0, 5.0, 6.0, 4.0]),),
...      (Vectors.dense([8.0, 9.0, 6.0, 5.0, 4.0, 4.0]),),
...      (Vectors.dense([8.0, 9.0, 6.0, 0.0, 0.0, 0.0]),)],
...     ["features"])
>>> selector = VarianceThresholdSelector(varianceThreshold=8.2, outputCol="selectedFeatures")
>>> model = selector.fit(df)
>>> model.getFeaturesCol()
'features'
>>> model.setFeaturesCol("features")
VarianceThresholdSelectorModel...
>>> model.transform(df).head().selectedFeatures
DenseVector([6.0, 7.0, 0.0])
>>> model.selectedFeatures
[0, 3, 5]
>>> varianceThresholdSelectorPath = temp_path + "/variance-threshold-selector"
>>> selector.save(varianceThresholdSelectorPath)
>>> loadedSelector = VarianceThresholdSelector.load(varianceThresholdSelectorPath)
>>> loadedSelector.getVarianceThreshold() == selector.getVarianceThreshold()
True
>>> modelPath = temp_path + "/variance-threshold-selector-model"
>>> model.save(modelPath)
>>> loadedModel = VarianceThresholdSelectorModel.load(modelPath)
>>> loadedModel.selectedFeatures == model.selectedFeatures
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
getVarianceThreshold(): Gets the value of varianceThreshold or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setFeaturesCol(value): Sets the value of featuresCol.
setOutputCol(value): Sets the value of outputCol.
setParams(self, *[, featuresCol, ...]): Sets params for this VarianceThresholdSelector.
setVarianceThreshold(value): Sets the value of varianceThreshold.
write(): Returns an MLWriter instance for this ML instance.
Attributes
featuresCol
outputCol
params: Returns all params ordered by name.
varianceThreshold
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable that contains one model for each param map. Each call to next(modelIterator) returns (index, model), where the model was fit using paramMaps[index]. Index values may not be sequential.
getFeaturesCol()
    Gets the value of featuresCol or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

getVarianceThreshold()
    Gets the value of varianceThreshold or its default value.
New in version 3.1.0.
hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setFeaturesCol(value)
    Sets the value of featuresCol.
New in version 3.1.0.
setOutputCol(value)
    Sets the value of outputCol.
New in version 3.1.0.
setParams(self, *, featuresCol="features", outputCol=None, varianceThreshold=0.0)
    Sets params for this VarianceThresholdSelector.
New in version 3.1.0.
setVarianceThreshold(value)
    Sets the value of varianceThreshold.
New in version 3.1.0.
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
varianceThreshold = Param(parent='undefined', name='varianceThreshold', doc='Param for variance threshold. Features with a variance not greater than this threshold will be removed. The default value is 0.0.')
VarianceThresholdSelectorModel
class pyspark.ml.feature.VarianceThresholdSelectorModel(java_model=None)
    Model fitted by VarianceThresholdSelector.
New in version 3.1.0.
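The original page carries no example for this class; a minimal editorial sketch that mirrors the VarianceThresholdSelector example above (same df):

>>> selector = VarianceThresholdSelector(varianceThreshold=8.2, outputCol="selectedFeatures")
>>> model = selector.fit(df)
>>> model.selectedFeatures  # indices of the features that survived the threshold
[0, 3, 5]
>>> model.transform(df).head().selectedFeatures
DenseVector([6.0, 7.0, 0.0])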
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
getVarianceThreshold(): Gets the value of varianceThreshold or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setFeaturesCol(value): Sets the value of featuresCol.
setOutputCol(value): Sets the value of outputCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
featuresCol
outputCol
params: Returns all params ordered by name.
selectedFeatures: List of indices to select (filter).
varianceThreshold
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getFeaturesCol()
    Gets the value of featuresCol or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.
getVarianceThreshold()
    Gets the value of varianceThreshold or its default value.
New in version 3.1.0.
hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setFeaturesCol(value)
    Sets the value of featuresCol.
New in version 3.1.0.
setOutputCol(value)Sets the value of outputCol.
New in version 3.1.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

selectedFeatures
List of indices to select (filter).
New in version 3.1.0.
varianceThreshold = Param(parent='undefined', name='varianceThreshold', doc='Param for variance threshold. Features with a variance not greater than this threshold will be removed. The default value is 0.0.')
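A minimal sketch of reading selectedFeatures from a fitted model (a hedged illustration with made-up data; it assumes an active SparkSession named spark):

>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.feature import VarianceThresholdSelector
>>> df = spark.createDataFrame(
...     [(Vectors.dense([1.0, 2.0]),), (Vectors.dense([1.0, 4.0]),)], ["features"])
>>> model = VarianceThresholdSelector(featuresCol="features").fit(df)
>>> model.selectedFeatures  # feature 0 is constant, so only feature 1 survives
[1]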
VectorAssembler
class pyspark.ml.feature.VectorAssembler(*args, **kwargs)
A feature transformer that merges multiple columns into a vector column.
New in version 1.4.0.
Examples
>>> df = spark.createDataFrame([(1, 0, 3)], ["a", "b", "c"])
>>> vecAssembler = VectorAssembler(outputCol="features")
>>> vecAssembler.setInputCols(["a", "b", "c"])
VectorAssembler...
>>> vecAssembler.transform(df).head().features
DenseVector([1.0, 0.0, 3.0])
>>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
DenseVector([1.0, 0.0, 3.0])
>>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
>>> vecAssembler.transform(df, params).head().vector
DenseVector([0.0, 1.0])
>>> vectorAssemblerPath = temp_path + "/vector-assembler"
>>> vecAssembler.save(vectorAssemblerPath)
>>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath)
>>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs
True
>>> dfWithNullsAndNaNs = spark.createDataFrame(
...     [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"])
>>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features",
...     handleInvalid="keep")
>>> vecAssembler2.transform(dfWithNullsAndNaNs).show()
+---+---+----+-------------+
|  a|  b|   c|     features|
+---+---+----+-------------+
|1.0|2.0|null|[1.0,2.0,NaN]|
|3.0|NaN| 4.0|[3.0,NaN,4.0]|
|5.0|6.0| 7.0|[5.0,6.0,7.0]|
+---+---+----+-------------+
...
>>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show()
+---+---+---+-------------+
|  a|  b|  c|     features|
+---+---+---+-------------+
|5.0|6.0|7.0|[5.0,6.0,7.0]|
+---+---+---+-------------+
...
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getHandleInvalid()  Gets the value of handleInvalid or its default value.
getInputCols()  Gets the value of inputCols or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setHandleInvalid(value)  Sets the value of handleInvalid.
setInputCols(value)  Sets the value of inputCols.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, *[, inputCols, outputCol, ...])  Sets params for this VectorAssembler.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
handleInvalid
inputCols
outputCol
params  Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
getHandleInvalid()
Gets the value of handleInvalid or its default value.

getInputCols()
Gets the value of inputCols or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setHandleInvalid(value)
Sets the value of handleInvalid.

setInputCols(value)
Sets the value of inputCols.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, *, inputCols=None, outputCol=None, handleInvalid="error")
Sets params for this VectorAssembler.
New in version 1.4.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
handleInvalid = Param(parent='undefined', name='handleInvalid', doc="How to handle invalid data (NULL and NaN values). Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the output). Column lengths are taken from the size of ML Attribute Group, which can be set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred from first rows of the data since it is safe to do so but only in case of 'error' or 'skip').")
inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
VectorIndexer
class pyspark.ml.feature.VectorIndexer(*args, **kwargs)
Class for indexing categorical feature columns in a dataset of Vector.
This has 2 usage modes:
• Automatically identify categorical features (default behavior)
– This helps process a dataset of unknown vectors into a dataset with some continuous features and some categorical features. The choice between continuous and categorical is based upon a maxCategories parameter.

– Set maxCategories to the maximum number of categories any categorical feature should have.

– E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. If maxCategories = 2, then feature 0 will be declared categorical and use indices {0, 1}, and feature 1 will be declared continuous.
• Index all features, if all features are categorical
– If maxCategories is set to be very large, then this will build an index of unique values for all features.

– Warning: This can cause problems if features are continuous since this will collect ALL unique values to the driver.
– E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. If maxCategories >= 3, then both features will be declared categorical.
This returns a model which can transform categorical features to use 0-based indices.
Index stability:
• This is not guaranteed to choose the same category index across multiple runs.
• If a categorical feature includes value 0, then this is guaranteed to map value 0 to index 0. This maintains vector sparsity.
• More stability may be added in the future.
TODO: Future extensions: The following functionality is planned for the future:
• Preserve metadata in transform; if a feature’s metadata is already present, do not recompute.
• Specify certain features to not index, either via a parameter or via existing metadata.
• Add warning if a categorical feature has only 1 category.
New in version 1.4.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
...     (Vectors.dense([0.0, 1.0]),), (Vectors.dense([0.0, 2.0]),)], ["a"])
>>> indexer = VectorIndexer(maxCategories=2, inputCol="a")
>>> indexer.setOutputCol("indexed")
VectorIndexer...
>>> model = indexer.fit(df)
>>> indexer.getHandleInvalid()
'error'
>>> model.setOutputCol("output")
VectorIndexerModel...
>>> model.transform(df).head().output
DenseVector([1.0, 0.0])
>>> model.numFeatures
2
>>> model.categoryMaps
{0: {0.0: 0, -1.0: 1}}
>>> indexer.setParams(outputCol="test").fit(df).transform(df).collect()[1].test
DenseVector([0.0, 1.0])
>>> params = {indexer.maxCategories: 3, indexer.outputCol: "vector"}
>>> model2 = indexer.fit(df, params)
>>> model2.transform(df).head().vector
DenseVector([1.0, 0.0])
>>> vectorIndexerPath = temp_path + "/vector-indexer"
>>> indexer.save(vectorIndexerPath)
>>> loadedIndexer = VectorIndexer.load(vectorIndexerPath)
>>> loadedIndexer.getMaxCategories() == indexer.getMaxCategories()
True
>>> modelPath = temp_path + "/vector-indexer-model"
>>> model.save(modelPath)
>>> loadedModel = VectorIndexerModel.load(modelPath)
>>> loadedModel.numFeatures == model.numFeatures
True
>>> loadedModel.categoryMaps == model.categoryMaps
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
>>> dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"])
>>> indexer.getHandleInvalid()
'error'
>>> model3 = indexer.setHandleInvalid("skip").fit(df)
>>> model3.transform(dfWithInvalid).count()
0
>>> model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df)
>>> model4.transform(dfWithInvalid).head().indexed
DenseVector([2.0, 1.0])
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getHandleInvalid()  Gets the value of handleInvalid or its default value.
getInputCol()  Gets the value of inputCol or its default value.
getMaxCategories()  Gets the value of maxCategories or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setHandleInvalid(value)  Sets the value of handleInvalid.
setInputCol(value)  Sets the value of inputCol.
setMaxCategories(value)  Sets the value of maxCategories.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, *[, maxCategories, ...])  Sets params for this VectorIndexer.
write()  Returns an MLWriter instance for this ML instance.
Attributes
handleInvalid
inputCol
maxCategories
outputCol
params  Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
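A minimal sketch of consuming the iterator (a hedged illustration reusing indexer and df from the Examples section above; models are stored by their returned index because the iterator may yield them out of order):

>>> maps = [{indexer.maxCategories: 2}, {indexer.maxCategories: 3}]
>>> models = [None] * len(maps)
>>> for i, m in indexer.fitMultiple(df, maps):
...     models[i] = m
>>> [m.numFeatures for m in models]
[2, 2]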
getHandleInvalid()
Gets the value of handleInvalid or its default value.

getInputCol()
Gets the value of inputCol or its default value.

getMaxCategories()
Gets the value of maxCategories or its default value.

New in version 1.4.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.
setHandleInvalid(value)
Sets the value of handleInvalid.

setInputCol(value)
Sets the value of inputCol.

setMaxCategories(value)
Sets the value of maxCategories.

New in version 1.4.0.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, *, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
Sets params for this VectorIndexer.

New in version 1.4.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
handleInvalid = Param(parent='undefined', name='handleInvalid', doc="How to handle invalid data (unseen labels or NULL values). Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), or 'keep' (put invalid data in a special additional bucket, at index of the number of categories of the feature).")
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
maxCategories = Param(parent='undefined', name='maxCategories', doc='Threshold for the number of values a categorical feature can take (>= 2). If a feature is found to have > maxCategories values, then it is declared continuous.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
VectorIndexerModel
class pyspark.ml.feature.VectorIndexerModel(java_model=None)
Model fitted by VectorIndexer.
Transform categorical features to use 0-based indices instead of their original values.
• Categorical features are mapped to indices.
• Continuous features (columns) are left unchanged.
This also appends metadata to the output column, marking features as Numeric (continuous), Nominal (categorical), or Binary (either continuous or categorical). Non-ML metadata is not carried over from the input to the output column.
This maintains vector sparsity.
New in version 1.4.0.
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getHandleInvalid()  Gets the value of handleInvalid or its default value.
getInputCol()  Gets the value of inputCol or its default value.
getMaxCategories()  Gets the value of maxCategories or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setInputCol(value)  Sets the value of inputCol.
setOutputCol(value)  Sets the value of outputCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
categoryMaps  Feature value index.
handleInvalid
inputCol
maxCategories
numFeatures  Number of features, i.e., length of Vectors which this transforms.
outputCol
params  Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getHandleInvalid()
Gets the value of handleInvalid or its default value.

getInputCol()
Gets the value of inputCol or its default value.

getMaxCategories()
Gets the value of maxCategories or its default value.

New in version 1.4.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.

New in version 3.0.0.

setOutputCol(value)
Sets the value of outputCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
categoryMaps
Feature value index. Keys are categorical feature indices (column indices). Values are maps from original feature values to 0-based category indices. If a feature is not in this map, it is treated as continuous.
New in version 1.4.0.
handleInvalid = Param(parent='undefined', name='handleInvalid', doc="How to handle invalid data (unseen labels or NULL values). Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), or 'keep' (put invalid data in a special additional bucket, at index of the number of categories of the feature).")
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
maxCategories = Param(parent='undefined', name='maxCategories', doc='Threshold for the number of values a categorical feature can take (>= 2). If a feature is found to have > maxCategories values, then it is declared continuous.')
numFeatures
Number of features, i.e., length of Vectors which this transforms.
New in version 1.4.0.
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
VectorSizeHint
class pyspark.ml.feature.VectorSizeHint(*args, **kwargs)
A feature transformer that adds size information to the metadata of a vector column. VectorAssembler needs size information for its input columns and cannot be used on streaming dataframes without this metadata.
New in version 2.3.0.
Notes
VectorSizeHint modifies inputCol to include size metadata and does not have an outputCol.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml import Pipeline, PipelineModel
>>> data = [(Vectors.dense([1., 2., 3.]), 4.)]
>>> df = spark.createDataFrame(data, ["vector", "float"])
>>>
>>> sizeHint = VectorSizeHint(inputCol="vector", size=3, handleInvalid="skip")
>>> vecAssembler = VectorAssembler(inputCols=["vector", "float"], outputCol="assembled")
>>> pipeline = Pipeline(stages=[sizeHint, vecAssembler])
>>>
>>> pipelineModel = pipeline.fit(df)
>>> pipelineModel.transform(df).head().assembled
DenseVector([1.0, 2.0, 3.0, 4.0])
>>> vectorSizeHintPath = temp_path + "/vector-size-hint-pipeline"
>>> pipelineModel.save(vectorSizeHintPath)
>>> loadedPipeline = PipelineModel.load(vectorSizeHintPath)
>>> loaded = loadedPipeline.transform(df).head().assembled
>>> expected = pipelineModel.transform(df).head().assembled
>>> loaded == expected
True
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getHandleInvalid()  Gets the value of handleInvalid or its default value.
getInputCol()  Gets the value of inputCol or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getSize()  Gets size param, the size of vectors in inputCol.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setHandleInvalid(value)  Sets the value of handleInvalid.
setInputCol(value)  Sets the value of inputCol.
setParams(self, *[, inputCol, size, ...])  Sets params for this VectorSizeHint.
setSize(value)  Sets size param, the size of vectors in inputCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
handleInvalid
inputCol
params  Returns all params ordered by name.
size
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getHandleInvalid()
Gets the value of handleInvalid or its default value.

getInputCol()
Gets the value of inputCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getSize()
Gets size param, the size of vectors in inputCol.
New in version 2.3.0.
hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setHandleInvalid(value)
Sets the value of handleInvalid.

setInputCol(value)
Sets the value of inputCol.

setParams(self, *, inputCol=None, size=None, handleInvalid="error")
Sets params for this VectorSizeHint.

New in version 2.3.0.

setSize(value)
Sets size param, the size of vectors in inputCol.

New in version 2.3.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
handleInvalid = Param(parent='undefined', name='handleInvalid', doc='How to handle invalid vectors in inputCol. Invalid vectors include nulls and vectors with the wrong size. The options are `skip` (filter out rows with invalid vectors), `error` (throw an error) and `optimistic` (do not check the vector size, and keep all rows). `error` by default.')
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
size = Param(parent='undefined', name='size', doc='Size of vectors in column.')
VectorSlicer
class pyspark.ml.feature.VectorSlicer(*args, **kwargs)
This class takes a feature vector and outputs a new feature vector with a subarray of the original features.

The subset of features can be specified with either indices (setIndices()) or names (setNames()). At least one feature must be selected. Duplicate features are not allowed, so there can be no overlap between selected indices and names.

The output vector will order features with the selected indices first (in the order given), followed by the selected names (in the order given).
New in version 1.6.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (Vectors.dense([-2.0, 2.3, 0.0, 0.0, 1.0]),),
...     (Vectors.dense([0.0, 0.0, 0.0, 0.0, 0.0]),),
...     (Vectors.dense([0.6, -1.1, -3.0, 4.5, 3.3]),)], ["features"])
>>> vs = VectorSlicer(outputCol="sliced", indices=[1, 4])
>>> vs.setInputCol("features")
VectorSlicer...
>>> vs.transform(df).head().sliced
DenseVector([2.3, 1.0])
>>> vectorSlicerPath = temp_path + "/vector-slicer"
>>> vs.save(vectorSlicerPath)
>>> loadedVs = VectorSlicer.load(vectorSlicerPath)
>>> loadedVs.getIndices() == vs.getIndices()
True
>>> loadedVs.getNames() == vs.getNames()
True
>>> loadedVs.transform(df).take(1) == vs.transform(df).take(1)
True
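The doctest above selects features by indices only. Selecting by names requires ML attribute metadata on the vector column; a minimal sketch (a hedged example, assuming the vector is produced by VectorAssembler, which names each assembled feature after its source column, and an active SparkSession named spark):

>>> df2 = spark.createDataFrame([(1.0, 2.0, 3.0)], ["f0", "f1", "f2"])
>>> assembled = VectorAssembler(inputCols=["f0", "f1", "f2"],
...     outputCol="features").transform(df2)
>>> vs2 = VectorSlicer(inputCol="features", outputCol="sliced",
...     indices=[2], names=["f0"])
>>> vs2.transform(assembled).head().sliced  # selected indices first, then names
DenseVector([3.0, 1.0])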
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getIndices()  Gets the value of indices or its default value.
getInputCol()  Gets the value of inputCol or its default value.
getNames()  Gets the value of names or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setIndices(value)  Sets the value of indices.
setInputCol(value)  Sets the value of inputCol.
setNames(value)  Sets the value of names.
setOutputCol(value)  Sets the value of outputCol.
setParams(*args, **kwargs)  setParams(self, *, inputCol=None, outputCol=None, indices=None, names=None): Sets params for this VectorSlicer.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
indices
inputCol
names
outputCol
params  Returns all params ordered by name.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getIndices()
Gets the value of indices or its default value.

New in version 1.6.0.

getInputCol()
Gets the value of inputCol or its default value.

getNames()
Gets the value of names or its default value.

New in version 1.6.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setIndices(value)
Sets the value of indices.

New in version 1.6.0.

setInputCol(value)
Sets the value of inputCol.

setNames(value)
Sets the value of names.

New in version 1.6.0.

setOutputCol(value)
Sets the value of outputCol.

setParams(*args, **kwargs)
setParams(self, *, inputCol=None, outputCol=None, indices=None, names=None): Sets params for this VectorSlicer.

New in version 1.6.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
indices = Param(parent='undefined', name='indices', doc='An array of indices to select features from a vector column. There can be no overlap with names.')
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
names = Param(parent='undefined', name='names', doc='An array of feature names to select features from a vector column. These names must be specified by ML org.apache.spark.ml.attribute.Attribute. There can be no overlap with indices.')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
Word2Vec
class pyspark.ml.feature.Word2Vec(*args, **kwargs)
Word2Vec trains a model of Map(String, Vector), i.e. it transforms a word into a code for further natural language processing or machine learning.
New in version 1.4.0.
Examples
>>> sent = ("a b " * 100 + "a c " * 10).split(" ")
>>> doc = spark.createDataFrame([(sent,), (sent,)], ["sentence"])
>>> word2Vec = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model")
>>> word2Vec.setMaxIter(10)
Word2Vec...
>>> word2Vec.getMaxIter()
10
>>> word2Vec.clear(word2Vec.maxIter)
>>> model = word2Vec.fit(doc)
>>> model.getMinCount()
5
>>> model.setInputCol("sentence")
Word2VecModel...
>>> model.getVectors().show()
+----+--------------------+
|word|              vector|
+----+--------------------+
|   a|[0.09511678665876...|
|   b|[-1.2028766870498...|
|   c|[0.30153277516365...|
+----+--------------------+
...
>>> model.findSynonymsArray("a", 2)
[('b', 0.015859870240092278), ('c', -0.5680795907974243)]
>>> from pyspark.sql.functions import format_number as fmt
>>> model.findSynonyms("a", 2).select("word", fmt("similarity", 5).alias("similarity")).show()
+----+----------+
|word|similarity|
+----+----------+
|   b|   0.01586|
|   c|  -0.56808|
+----+----------+
...
>>> model.transform(doc).head().model
DenseVector([-0.4833, 0.1855, -0.273, -0.0509, -0.4769])
>>> word2vecPath = temp_path + "/word2vec"
>>> word2Vec.save(word2vecPath)
>>> loadedWord2Vec = Word2Vec.load(word2vecPath)
>>> loadedWord2Vec.getVectorSize() == word2Vec.getVectorSize()
True
>>> loadedWord2Vec.getNumPartitions() == word2Vec.getNumPartitions()
True
>>> loadedWord2Vec.getMinCount() == word2Vec.getMinCount()
True
>>> modelPath = temp_path + "/word2vec-model"
>>> model.save(modelPath)
>>> loadedModel = Word2VecModel.load(modelPath)
>>> loadedModel.getVectors().first().word == model.getVectors().first().word
True
>>> loadedModel.getVectors().first().vector == model.getVectors().first().vector
True
>>> loadedModel.transform(doc).take(1) == model.transform(doc).take(1)
True
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getInputCol()  Gets the value of inputCol or its default value.
getMaxIter()  Gets the value of maxIter or its default value.
getMaxSentenceLength()  Gets the value of maxSentenceLength or its default value.
getMinCount()  Gets the value of minCount or its default value.
getNumPartitions()  Gets the value of numPartitions or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getSeed()  Gets the value of seed or its default value.
getStepSize()  Gets the value of stepSize or its default value.
getVectorSize()  Gets the value of vectorSize or its default value.
getWindowSize()  Gets the value of windowSize or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setInputCol(value)  Sets the value of inputCol.
setMaxIter(value)  Sets the value of maxIter.
setMaxSentenceLength(value)  Sets the value of maxSentenceLength.
setMinCount(value)  Sets the value of minCount.
setNumPartitions(value)  Sets the value of numPartitions.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, *[, minCount, ...])  Sets params for this Word2Vec.
setSeed(value)  Sets the value of seed.
setStepSize(value)  Sets the value of stepSize.
setVectorSize(value)  Sets the value of vectorSize.
setWindowSize(value)  Sets the value of windowSize.
write()  Returns an MLWriter instance for this ML instance.
Attributes
inputCol
maxIter
maxSentenceLength
minCount
numPartitions
outputCol
params  Returns all params ordered by name.
seed
stepSize
vectorSize
windowSize
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getInputCol()
Gets the value of inputCol or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getMaxSentenceLength()
Gets the value of maxSentenceLength or its default value.

New in version 2.0.0.

getMinCount()
Gets the value of minCount or its default value.

New in version 1.4.0.

getNumPartitions()
Gets the value of numPartitions or its default value.

New in version 1.4.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getSeed()
Gets the value of seed or its default value.

getStepSize()
Gets the value of stepSize or its default value.

getVectorSize()
Gets the value of vectorSize or its default value.

New in version 1.4.0.

getWindowSize()
Gets the value of windowSize or its default value.

New in version 2.0.0.

hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.

setMaxIter(value)
Sets the value of maxIter.

setMaxSentenceLength(value)
Sets the value of maxSentenceLength.

New in version 2.0.0.

setMinCount(value)
Sets the value of minCount.

New in version 1.4.0.

setNumPartitions(value)
Sets the value of numPartitions.

New in version 1.4.0.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, *, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None, inputCol=None, outputCol=None, windowSize=5, maxSentenceLength=1000)
Sets params for this Word2Vec.

New in version 1.4.0.

setSeed(value)
Sets the value of seed.

setStepSize(value)
Sets the value of stepSize.

New in version 1.4.0.

setVectorSize(value)
Sets the value of vectorSize.

New in version 1.4.0.
setWindowSize(value)
Sets the value of windowSize.

New in version 2.0.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
maxSentenceLength = Param(parent='undefined', name='maxSentenceLength', doc='Maximum length (in words) of each sentence in the input data. Any sentence longer than this threshold will be divided into chunks up to the size.')
minCount = Param(parent='undefined', name='minCount', doc="the minimum number of times a token must appear to be included in the word2vec model's vocabulary")
numPartitions = Param(parent='undefined', name='numPartitions', doc='number of partitions for sentences of words')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
seed = Param(parent='undefined', name='seed', doc='random seed.')
stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')
vectorSize = Param(parent='undefined', name='vectorSize', doc='the dimension of codes after transforming from words')
windowSize = Param(parent='undefined', name='windowSize', doc='the window size (context words from [-window, window]). Default value is 5')
Word2VecModel
class pyspark.ml.feature.Word2VecModel(java_model=None)
Model fitted by Word2Vec.
New in version 1.4.0.
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
findSynonyms(word, num)  Find the "num" words closest in similarity to "word".
findSynonymsArray(word, num)  Find the "num" words closest in similarity to "word".
getInputCol()  Gets the value of inputCol or its default value.
getMaxIter()  Gets the value of maxIter or its default value.
getMaxSentenceLength()  Gets the value of maxSentenceLength or its default value.
getMinCount()  Gets the value of minCount or its default value.
getNumPartitions()  Gets the value of numPartitions or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getSeed()  Gets the value of seed or its default value.
getStepSize()  Gets the value of stepSize or its default value.
getVectorSize()  Gets the value of vectorSize or its default value.
getVectors()  Returns the vector representation of the words as a dataframe with two fields, word and vector.
getWindowSize()  Gets the value of windowSize or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setInputCol(value)  Sets the value of inputCol.
setOutputCol(value)  Sets the value of outputCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
inputCol
maxIter
maxSentenceLength
minCount
numPartitions
outputCol
params  Returns all params ordered by name.
seed
stepSize
vectorSize
windowSize
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

findSynonyms(word, num)
Find the "num" words closest in similarity to "word". word can be a string or vector representation. Returns a dataframe with two fields word and similarity (which gives the cosine similarity).
New in version 1.5.0.
findSynonymsArray(word, num)
Find the "num" words closest in similarity to "word". word can be a string or vector representation. Returns an array with two fields word and similarity (which gives the cosine similarity).
New in version 2.3.0.
getInputCol()Gets the value of inputCol or its default value.
getMaxIter()Gets the value of maxIter or its default value.
getMaxSentenceLength()Gets the value of maxSentenceLength or its default value.
New in version 2.0.0.
getMinCount()Gets the value of minCount or its default value.
New in version 1.4.0.
getNumPartitions()Gets the value of numPartitions or its default value.
New in version 1.4.0.
getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.
getOutputCol()Gets the value of outputCol or its default value.
getParam(paramName)Gets a param by its name.
getSeed()Gets the value of seed or its default value.
getStepSize()Gets the value of stepSize or its default value.
getVectorSize()Gets the value of vectorSize or its default value.
New in version 1.4.0.
getVectors()Returns the vector representation of the words as a dataframe with two fields, word and vector.
New in version 1.5.0.
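Continuing the findSynonyms sketch above, the returned dataframe can be inspected directly:

vecs = model.getVectors()        # columns: word, vector
vecs.show(truncate=False)
row = vecs.first()
print(row.word, row.vector)      # a word and its learned DenseVector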
getWindowSize()Gets the value of windowSize or its default value.
New in version 2.0.0.
hasDefault(param)Checks whether a param has a default value.
hasParam(paramName)Tests whether this instance contains a param with a given (string) name.
isDefined(param)Checks whether a param is explicitly set by user or has a default value.
isSet(param)Checks whether a param is explicitly set by user.
classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()Returns an MLReader instance for this class.
save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)Sets a parameter in the embedded param map.
setInputCol(value)Sets the value of inputCol.
setOutputCol(value)Sets the value of outputCol.
transform(dataset, params=None)Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
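As a sketch (reusing the model and doc from the findSynonyms example above), a param map passed to transform overrides the embedded params for that call only:

out = model.transform(doc)                              # writes column "vec"
out2 = model.transform(doc, {model.outputCol: "vec2"})  # one-off override
print(out.columns, out2.columns)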
write()Returns an MLWriter instance for this ML instance.
Attributes Documentation
inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
maxSentenceLength = Param(parent='undefined', name='maxSentenceLength', doc='Maximum length (in words) of each sentence in the input data. Any sentence longer than this threshold will be divided into chunks up to the size.')
minCount = Param(parent='undefined', name='minCount', doc="the minimum number of times a token must appear to be included in the word2vec model's vocabulary")
numPartitions = Param(parent='undefined', name='numPartitions', doc='number of partitions for sentences of words')
outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')
paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.
seed = Param(parent='undefined', name='seed', doc='random seed.')
stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')
vectorSize = Param(parent='undefined', name='vectorSize', doc='the dimension of codes after transforming from words')
windowSize = Param(parent='undefined', name='windowSize', doc='the window size (context words from [-window, window]). Default value is 5')
3.3.4 Classification
LinearSVC(*args, **kwargs)  This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.
LinearSVCModel([java_model])  Model fitted by LinearSVC.
LinearSVCSummary([java_obj])  Abstraction for LinearSVC Results for a given model.
LinearSVCTrainingSummary([java_obj])  Abstraction for LinearSVC Training results.
LogisticRegression(*args, **kwargs)  Logistic regression.
LogisticRegressionModel([java_model])  Model fitted by LogisticRegression.
LogisticRegressionSummary([java_obj])  Abstraction for Logistic Regression Results for a given model.
LogisticRegressionTrainingSummary([java_obj])  Abstraction for multinomial Logistic Regression Training results.
BinaryLogisticRegressionSummary([java_obj])  Binary Logistic regression results for a given model.
BinaryLogisticRegressionTrainingSummary([. . . ])  Binary Logistic regression training results for a given model.
DecisionTreeClassifier(*args, **kwargs)  Decision tree learning algorithm for classification. It supports both binary and multiclass labels, as well as both continuous and categorical features.
DecisionTreeClassificationModel([java_model])  Model fitted by DecisionTreeClassifier.
GBTClassifier(*args, **kwargs)  Gradient-Boosted Trees (GBTs) learning algorithm for classification. It supports binary labels, as well as both continuous and categorical features.
GBTClassificationModel([java_model])  Model fitted by GBTClassifier.
RandomForestClassifier(*args, **kwargs)  Random Forest learning algorithm for classification. It supports both binary and multiclass labels, as well as both continuous and categorical features.
RandomForestClassificationModel([java_model])  Model fitted by RandomForestClassifier.
RandomForestClassificationSummary([java_obj])  Abstraction for RandomForestClassification Results for a given model.
RandomForestClassificationTrainingSummary([. . . ])  Abstraction for RandomForestClassification Training results.
BinaryRandomForestClassificationSummary([. . . ])  BinaryRandomForestClassification results for a given model.
BinaryRandomForestClassificationTrainingSummary([. . . ])  BinaryRandomForestClassification training results for a given model.
NaiveBayes(*args, **kwargs)  Naive Bayes Classifiers.
NaiveBayesModel([java_model])  Model fitted by NaiveBayes.
MultilayerPerceptronClassifier(*args, **kwargs)  Classifier trainer based on the Multilayer Perceptron.
MultilayerPerceptronClassificationModel([. . . ])  Model fitted by MultilayerPerceptronClassifier.
MultilayerPerceptronClassificationSummary([. . . ])  Abstraction for MultilayerPerceptronClassifier Results for a given model.
MultilayerPerceptronClassificationTrainingSummary([. . . ])  Abstraction for MultilayerPerceptronClassifier Training results.
OneVsRest(*args, **kwargs)  Reduction of Multiclass Classification to Binary Classification.
OneVsRestModel(models)  Model fitted by OneVsRest.
FMClassifier(*args, **kwargs)  Factorization Machines learning algorithm for classification.
FMClassificationModel([java_model])  Model fitted by FMClassifier.
FMClassificationSummary([java_obj])  Abstraction for FMClassifier Results for a given model.
FMClassificationTrainingSummary([java_obj])  Abstraction for FMClassifier Training results.
LinearSVC
class pyspark.ml.classification.LinearSVC(*args, **kwargs)This binary classifier optimizes the Hinge Loss using the OWLQN optimizer. Only supports L2 regularizationcurrently.
New in version 2.2.0.
Notes
Linear SVM Classifier
Examples
>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
>>> df = sc.parallelize([
...     Row(label=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
...     Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
>>> svm = LinearSVC()
>>> svm.getMaxIter()
100
>>> svm.setMaxIter(5)
LinearSVC...
>>> svm.getMaxIter()
5
>>> svm.getRegParam()
0.0
>>> svm.setRegParam(0.01)
LinearSVC...
>>> svm.getRegParam()
0.01
>>> model = svm.fit(df)
>>> model.setPredictionCol("newPrediction")
LinearSVCModel...
>>> model.getPredictionCol()
'newPrediction'
>>> model.setThreshold(0.5)
LinearSVCModel...
>>> model.getThreshold()
0.5
>>> model.getMaxBlockSizeInMB()
0.0
>>> model.coefficients
DenseVector([0.0, -0.2792, -0.1833])
>>> model.intercept
1.0206118982229047
>>> model.numClasses
2
>>> model.numFeatures
3
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, -1.0, -1.0))]).toDF()
>>> model.predict(test0.head().features)
1.0
>>> model.predictRaw(test0.head().features)
DenseVector([-1.4831, 1.4831])
>>> result = model.transform(test0).head()
>>> result.newPrediction
1.0
>>> result.rawPrediction
DenseVector([-1.4831, 1.4831])
>>> svm_path = temp_path + "/svm"
>>> svm.save(svm_path)
>>> svm2 = LinearSVC.load(svm_path)
>>> svm2.getMaxIter()
5
>>> model_path = temp_path + "/svm_model"
>>> model.save(model_path)
>>> model2 = LinearSVCModel.load(model_path)
>>> model.coefficients[0] == model2.coefficients[0]
True
>>> model.intercept == model2.intercept
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
Methods
clear(param) Clears a param from the param map if it has beenexplicitly set.
copy([extra]) Creates a copy of this instance with the same uid andsome extra params.
explainParam(param) Explains a single param and returns its name, doc,and optional default value and user-supplied value ina string.
explainParams() Returns the documentation of all params with theiroptionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values anduser-supplied values, and then merges them with ex-tra values from input into a flat param map, wherethe latter value is used if there exist conflicts, i.e.,with ordering: default param values < user-suppliedvalues < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional param-eters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param mapin paramMaps.
getAggregationDepth() Gets the value of aggregationDepth or its defaultvalue.
continues on next page
560 Chapter 3. API Reference
pyspark Documentation, Release master
Table 190 – continued from previous pagegetFeaturesCol() Gets the value of featuresCol or its default value.getFitIntercept() Gets the value of fitIntercept or its default value.getLabelCol() Gets the value of labelCol or its default value.getMaxBlockSizeInMB() Gets the value of maxBlockSizeInMB or its default
value.getMaxIter() Gets the value of maxIter or its default value.getOrDefault(param) Gets the value of a param in the user-supplied param
map or its default value.getParam(paramName) Gets a param by its name.getPredictionCol() Gets the value of predictionCol or its default value.getRawPredictionCol() Gets the value of rawPredictionCol or its default
value.getRegParam() Gets the value of regParam or its default value.getStandardization() Gets the value of standardization or its default value.getThreshold() Gets the value of threshold or its default value.getTol() Gets the value of tol or its default value.getWeightCol() Gets the value of weightCol or its default value.hasDefault(param) Checks whether a param has a default value.hasParam(paramName) Tests whether this instance contains a param with a
given (string) name.isDefined(param) Checks whether a param is explicitly set by user or
has a default value.isSet(param) Checks whether a param is explicitly set by user.load(path) Reads an ML instance from the input path, a shortcut
of read().load(path).read() Returns an MLReader instance for this class.save(path) Save this ML instance to the given path, a shortcut of
‘write().save(path)’.set(param, value) Sets a parameter in the embedded param map.setAggregationDepth(value) Sets the value of aggregationDepth.setFeaturesCol(value) Sets the value of featuresCol.setFitIntercept(value) Sets the value of fitIntercept.setLabelCol(value) Sets the value of labelCol.setMaxBlockSizeInMB(value) Sets the value of maxBlockSizeInMB.setMaxIter(value) Sets the value of maxIter.setParams(*args, **kwargs) setParams(self, *, featuresCol=”features”, la-
belCol=”label”, predictionCol=”prediction”,maxIter=100, regParam=0.0, tol=1e-6, rawPredic-tionCol=”rawPrediction”, fitIntercept=True, stan-dardization=True, threshold=0.0, weightCol=None,aggregationDepth=2, maxBlockSizeInMB=0.0):Sets params for Linear SVM Classifier.
setPredictionCol(value) Sets the value of predictionCol.setRawPredictionCol(value) Sets the value of rawPredictionCol.setRegParam(value) Sets the value of regParam.setStandardization(value) Sets the value of standardization.setThreshold(value) Sets the value of threshold.setTol(value) Sets the value of tol.setWeightCol(value) Sets the value of weightCol.write() Returns an MLWriter instance for this ML instance.
Attributes
aggregationDepth
featuresCol
fitIntercept
labelCol
maxBlockSizeInMB
maxIter
params  Returns all params ordered by name.
predictionCol
rawPredictionCol
regParam
standardization
threshold
tol
weightCol
Methods Documentation
clear(param)Clears a param from the param map if it has been explicitly set.
copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.
explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
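A minimal sketch (assuming an active SparkSession; the two-row dataset mirrors the class example above):

from pyspark.sql import SparkSession, Row
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LinearSVC

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([
    Row(label=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
    Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))])
svm = LinearSVC(maxIter=5)
param_maps = [{svm.regParam: 0.01}, {svm.regParam: 0.1}]
# Yields (index, model) pairs, not necessarily in index order:
for index, model in svm.fitMultiple(df, param_maps):
    print(index, model.intercept)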
getAggregationDepth()Gets the value of aggregationDepth or its default value.
getFeaturesCol()Gets the value of featuresCol or its default value.
getFitIntercept()Gets the value of fitIntercept or its default value.
getLabelCol()Gets the value of labelCol or its default value.
getMaxBlockSizeInMB()Gets the value of maxBlockSizeInMB or its default value.
getMaxIter()Gets the value of maxIter or its default value.
getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.
getParam(paramName)Gets a param by its name.
getPredictionCol()Gets the value of predictionCol or its default value.
getRawPredictionCol()Gets the value of rawPredictionCol or its default value.
getRegParam()Gets the value of regParam or its default value.
getStandardization()Gets the value of standardization or its default value.
getThreshold()Gets the value of threshold or its default value.
getTol()Gets the value of tol or its default value.
getWeightCol()Gets the value of weightCol or its default value.
hasDefault(param)Checks whether a param has a default value.
hasParam(paramName)Tests whether this instance contains a param with a given (string) name.
isDefined(param)Checks whether a param is explicitly set by user or has a default value.
isSet(param)Checks whether a param is explicitly set by user.
classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()Returns an MLReader instance for this class.
save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)Sets a parameter in the embedded param map.
setAggregationDepth(value)Sets the value of aggregationDepth.
New in version 2.2.0.
setFeaturesCol(value)Sets the value of featuresCol.
New in version 3.0.0.
setFitIntercept(value)Sets the value of fitIntercept.
New in version 2.2.0.
setLabelCol(value)Sets the value of labelCol.
New in version 3.0.0.
setMaxBlockSizeInMB(value)Sets the value of maxBlockSizeInMB.
New in version 3.1.0.
setMaxIter(value)Sets the value of maxIter.
New in version 2.2.0.
setParams(*args, **kwargs)
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0): Sets params for Linear SVM Classifier.
New in version 2.2.0.
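Continuing the fitMultiple sketch above, setParams enforces keyword arguments; positional arguments raise a TypeError:

svm = LinearSVC()
svm.setParams(maxIter=20, regParam=0.1)     # keyword arguments only
print(svm.getMaxIter(), svm.getRegParam())  # 20 0.1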
setPredictionCol(value)Sets the value of predictionCol.
New in version 3.0.0.
setRawPredictionCol(value)Sets the value of rawPredictionCol.
New in version 3.0.0.
setRegParam(value)Sets the value of regParam.
New in version 2.2.0.
setStandardization(value)Sets the value of standardization.
New in version 2.2.0.
setThreshold(value)Sets the value of threshold.
New in version 2.2.0.
setTol(value)Sets the value of tol.
New in version 2.2.0.
setWeightCol(value)Sets the value of weightCol.
New in version 2.2.0.
write()Returns an MLWriter instance for this ML instance.
Attributes Documentation
aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')
standardization = Param(parent='undefined', name='standardization', doc='whether to standardize the training features before fitting the model.')
threshold = Param(parent='undefined', name='threshold', doc='The threshold in binary classification applied to the linear model prediction. This threshold can be any real number, where Inf will make all predictions 0.0 and -Inf will make all predictions 1.0.')
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
LinearSVCModel
class pyspark.ml.classification.LinearSVCModel(java_model=None)Model fitted by LinearSVC.
New in version 2.2.0.
Methods
clear(param) Clears a param from the param map if it has beenexplicitly set.
copy([extra]) Creates a copy of this instance with the same uid andsome extra params.
evaluate(dataset) Evaluates the model on a test dataset.explainParam(param) Explains a single param and returns its name, doc,
and optional default value and user-supplied value ina string.
explainParams() Returns the documentation of all params with theiroptionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values anduser-supplied values, and then merges them with ex-tra values from input into a flat param map, wherethe latter value is used if there exist conflicts, i.e.,with ordering: default param values < user-suppliedvalues < extra.
getAggregationDepth() Gets the value of aggregationDepth or its defaultvalue.
getFeaturesCol() Gets the value of featuresCol or its default value.getFitIntercept() Gets the value of fitIntercept or its default value.getLabelCol() Gets the value of labelCol or its default value.getMaxBlockSizeInMB() Gets the value of maxBlockSizeInMB or its default
value.getMaxIter() Gets the value of maxIter or its default value.getOrDefault(param) Gets the value of a param in the user-supplied param
map or its default value.getParam(paramName) Gets a param by its name.getPredictionCol() Gets the value of predictionCol or its default value.getRawPredictionCol() Gets the value of rawPredictionCol or its default
value.getRegParam() Gets the value of regParam or its default value.getStandardization() Gets the value of standardization or its default value.getThreshold() Gets the value of threshold or its default value.getTol() Gets the value of tol or its default value.
continues on next page
566 Chapter 3. API Reference
pyspark Documentation, Release master
Table 192 – continued from previous pagegetWeightCol() Gets the value of weightCol or its default value.hasDefault(param) Checks whether a param has a default value.hasParam(paramName) Tests whether this instance contains a param with a
given (string) name.isDefined(param) Checks whether a param is explicitly set by user or
has a default value.isSet(param) Checks whether a param is explicitly set by user.load(path) Reads an ML instance from the input path, a shortcut
of read().load(path).predict(value) Predict label for the given features.predictRaw(value) Raw prediction for each possible label.read() Returns an MLReader instance for this class.save(path) Save this ML instance to the given path, a shortcut of
‘write().save(path)’.set(param, value) Sets a parameter in the embedded param map.setFeaturesCol(value) Sets the value of featuresCol.setPredictionCol(value) Sets the value of predictionCol.setRawPredictionCol(value) Sets the value of rawPredictionCol.setThreshold(value) Sets the value of threshold.summary() Gets summary (e.g.transform(dataset[, params]) Transforms the input dataset with optional parame-
ters.write() Returns an MLWriter instance for this ML instance.
Attributes
aggregationDepth
coefficients  Model coefficients of Linear SVM Classifier.
featuresCol
fitIntercept
hasSummary  Indicates whether a training summary exists for this model instance.
intercept  Model intercept of Linear SVM Classifier.
labelCol
maxBlockSizeInMB
maxIter
numClasses  Number of classes (values which the label can take).
numFeatures  Returns the number of features the model was trained on.
params  Returns all params ordered by name.
predictionCol
rawPredictionCol
regParam
standardization
threshold
tol
weightCol
Methods Documentation
clear(param)Clears a param from the param map if it has been explicitly set.
copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
evaluate(dataset)Evaluates the model on a test dataset.
New in version 3.1.0.
Parameters
dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on.
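A hedged sketch, reusing the fitted model and training DataFrame df from the LinearSVC example (a real workflow would evaluate on a held-out split):

test_summary = model.evaluate(df)   # returns a LinearSVCSummary
print(test_summary.accuracy, test_summary.areaUnderROC)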
explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.
explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getAggregationDepth()Gets the value of aggregationDepth or its default value.
getFeaturesCol()Gets the value of featuresCol or its default value.
getFitIntercept()Gets the value of fitIntercept or its default value.
getLabelCol()Gets the value of labelCol or its default value.
getMaxBlockSizeInMB()Gets the value of maxBlockSizeInMB or its default value.
getMaxIter()Gets the value of maxIter or its default value.
getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.
getParam(paramName)Gets a param by its name.
getPredictionCol()Gets the value of predictionCol or its default value.
getRawPredictionCol()Gets the value of rawPredictionCol or its default value.
getRegParam()Gets the value of regParam or its default value.
getStandardization()Gets the value of standardization or its default value.
getThreshold()Gets the value of threshold or its default value.
getTol()Gets the value of tol or its default value.
getWeightCol()Gets the value of weightCol or its default value.
hasDefault(param)Checks whether a param has a default value.
hasParam(paramName)Tests whether this instance contains a param with a given (string) name.
isDefined(param)Checks whether a param is explicitly set by user or has a default value.
isSet(param)Checks whether a param is explicitly set by user.
classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)Predict label for the given features.
New in version 3.0.0.
predictRaw(value)Raw prediction for each possible label.
New in version 3.0.0.
classmethod read()Returns an MLReader instance for this class.
save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)Sets a parameter in the embedded param map.
setFeaturesCol(value)Sets the value of featuresCol.
New in version 3.0.0.
setPredictionCol(value)Sets the value of predictionCol.
New in version 3.0.0.
setRawPredictionCol(value)Sets the value of rawPredictionCol.
New in version 3.0.0.
setThreshold(value)Sets the value of threshold.
New in version 3.0.0.
summary()
Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of the model trained on the training set. An exception is thrown if trainingSummary is None.
New in version 3.1.0.
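A small sketch of guarded access, reusing the fitted model from the LinearSVC example (note that the examples in this chapter read the summary as an attribute):

if model.hasSummary:
    s = model.summary               # LinearSVCTrainingSummary
    print(s.totalIterations, s.objectiveHistory[-1])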
transform(dataset, params=None)Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()Returns an MLWriter instance for this ML instance.
Attributes Documentation
aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')
coefficientsModel coefficients of Linear SVM Classifier.
New in version 2.2.0.
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
hasSummaryIndicates whether a training summary exists for this model instance.
New in version 2.1.0.
interceptModel intercept of Linear SVM Classifier.
New in version 2.2.0.
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
numClassesNumber of classes (values which the label can take).
New in version 2.1.0.
numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')
standardization = Param(parent='undefined', name='standardization', doc='whether to standardize the training features before fitting the model.')
threshold = Param(parent='undefined', name='threshold', doc='The threshold in binary classification applied to the linear model prediction. This threshold can be any real number, where Inf will make all predictions 0.0 and -Inf will make all predictions 1.0.')
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
LinearSVCSummary
class pyspark.ml.classification.LinearSVCSummary(java_obj=None)Abstraction for LinearSVC Results for a given model.
New in version 3.1.0.
Methods
fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.
Attributes
accuracy  Returns accuracy.
areaUnderROC  Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold  Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
labelCol  Field in "predictions" which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
pr  Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel  Returns precision for each label (category).
precisionByThreshold  Returns a dataframe with two fields (threshold, precision) curve.
predictionCol  Field in "predictions" which gives the prediction of each class.
predictions  Dataframe outputted by the model's transform method.
recallByLabel  Returns recall for each label (category).
recallByThreshold  Returns a dataframe with two fields (threshold, recall) curve.
roc  Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol  Field in "predictions" which gives the probability or raw prediction of each class as a vector.
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).
New in version 3.1.0.
weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.
New in version 3.1.0.
Attributes Documentation
accuracy
Returns accuracy (the fraction of correctly classified instances out of the total number of instances).
New in version 3.1.0.
areaUnderROCComputes the area under the receiver operating characteristic (ROC) curve.
New in version 3.1.0.
fMeasureByThresholdReturns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
New in version 3.1.0.
falsePositiveRateByLabelReturns false positive rate for each label (category).
New in version 3.1.0.
labelColField in “predictions” which gives the true label of each instance.
New in version 3.1.0.
labelsReturns the sequence of labels in ascending order. This order matches the order used in metrics which arespecified as arrays over labels, e.g., truePositiveRateByLabel.
New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, . . . , numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.
prReturns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0,1.0) prepended to it.
New in version 3.1.0.
precisionByLabelReturns precision for each label (category).
New in version 3.1.0.
precisionByThreshold
Returns a dataframe with two fields: a (threshold, precision) curve. Every possible probability obtained in transforming the dataset is used as a threshold in calculating the precision.
New in version 3.1.0.
predictionColField in “predictions” which gives the prediction of each class.
New in version 3.1.0.
predictionsDataframe outputted by the model’s transform method.
New in version 3.1.0.
recallByLabelReturns recall for each label (category).
New in version 3.1.0.
recallByThreshold
Returns a dataframe with two fields: a (threshold, recall) curve. Every possible probability obtained in transforming the dataset is used as a threshold in calculating the recall.
New in version 3.1.0.
rocReturns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR,TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
New in version 3.1.0.
Notes
Wikipedia reference
scoreColField in “predictions” which gives the probability or raw prediction of each class as a vector.
New in version 3.1.0.
truePositiveRateByLabelReturns true positive rate for each label (category).
New in version 3.1.0.
weightColField in “predictions” which gives the weight of each instance as a vector.
New in version 3.1.0.
weightedFalsePositiveRateReturns weighted false positive rate.
New in version 3.1.0.
weightedPrecisionReturns weighted averaged precision.
New in version 3.1.0.
weightedRecall
Returns weighted averaged recall (equal to precision, recall and f-measure).
New in version 3.1.0.
weightedTruePositiveRate
Returns weighted true positive rate (equal to precision, recall and f-measure).
New in version 3.1.0.
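Continuing the evaluate sketch above, the threshold-based curves are plain DataFrames and can be inspected directly:

test_summary.roc.show(5)                   # columns FPR, TPR
test_summary.pr.show(5)                    # columns recall, precision
test_summary.fMeasureByThreshold.show(5)   # columns threshold, F-Measure
print(test_summary.areaUnderROC)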
LinearSVCTrainingSummary
class pyspark.ml.classification.LinearSVCTrainingSummary(java_obj=None)Abstraction for LinearSVC Training results.
New in version 3.1.0.
Methods
fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.
Attributes
accuracy  Returns accuracy.
areaUnderROC  Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold  Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
labelCol  Field in "predictions" which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
objectiveHistory  Objective function (scaled loss + regularization) at each iteration.
pr  Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel  Returns precision for each label (category).
precisionByThreshold  Returns a dataframe with two fields (threshold, precision) curve.
predictionCol  Field in "predictions" which gives the prediction of each class.
predictions  Dataframe outputted by the model's transform method.
recallByLabel  Returns recall for each label (category).
recallByThreshold  Returns a dataframe with two fields (threshold, recall) curve.
roc  Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol  Field in "predictions" which gives the probability or raw prediction of each class as a vector.
totalIterations  Number of training iterations until termination.
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).
New in version 3.1.0.
weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.
New in version 3.1.0.
Attributes Documentation
accuracy
Returns accuracy (the fraction of correctly classified instances out of the total number of instances).
New in version 3.1.0.
areaUnderROCComputes the area under the receiver operating characteristic (ROC) curve.
New in version 3.1.0.
fMeasureByThresholdReturns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
New in version 3.1.0.
falsePositiveRateByLabelReturns false positive rate for each label (category).
New in version 3.1.0.
labelColField in “predictions” which gives the true label of each instance.
New in version 3.1.0.
labelsReturns the sequence of labels in ascending order. This order matches the order used in metrics which arespecified as arrays over labels, e.g., truePositiveRateByLabel.
New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, . . . , numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.
objectiveHistory
Objective function (scaled loss + regularization) at each iteration. It contains one more element (the initial state) than the number of iterations.
New in version 3.1.0.
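For example, a rough convergence check over a training summary (a sketch, assuming a fitted model with a summary):

hist = model.summary.objectiveHistory    # length = total iterations + 1
deltas = [a - b for a, b in zip(hist, hist[1:])]
print("objective non-increasing:", all(d >= 0 for d in deltas))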
prReturns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0,1.0) prepended to it.
New in version 3.1.0.
precisionByLabelReturns precision for each label (category).
New in version 3.1.0.
precisionByThreshold
Returns a dataframe with two fields: a (threshold, precision) curve. Every possible probability obtained in transforming the dataset is used as a threshold in calculating the precision.
New in version 3.1.0.
predictionColField in “predictions” which gives the prediction of each class.
New in version 3.1.0.
predictionsDataframe outputted by the model’s transform method.
New in version 3.1.0.
recallByLabelReturns recall for each label (category).
New in version 3.1.0.
recallByThreshold
Returns a dataframe with two fields: a (threshold, recall) curve. Every possible probability obtained in transforming the dataset is used as a threshold in calculating the recall.
New in version 3.1.0.
rocReturns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR,TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
New in version 3.1.0.
Notes
Wikipedia reference
scoreColField in “predictions” which gives the probability or raw prediction of each class as a vector.
New in version 3.1.0.
totalIterationsNumber of training iterations until termination.
New in version 3.1.0.
truePositiveRateByLabelReturns true positive rate for each label (category).
New in version 3.1.0.
weightColField in “predictions” which gives the weight of each instance as a vector.
New in version 3.1.0.
weightedFalsePositiveRateReturns weighted false positive rate.
New in version 3.1.0.
weightedPrecisionReturns weighted averaged precision.
New in version 3.1.0.
weightedRecall
Returns weighted averaged recall (equal to precision, recall and f-measure).
New in version 3.1.0.
weightedTruePositiveRate
Returns weighted true positive rate (equal to precision, recall and f-measure).
New in version 3.1.0.
LogisticRegression
class pyspark.ml.classification.LogisticRegression(*args, **kwargs)Logistic regression. This class supports multinomial logistic (softmax) and binomial logistic regression.
New in version 1.3.0.
Examples
>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
>>> bdf = sc.parallelize([
...     Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)),
...     Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)),
...     Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)),
...     Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF()
>>> blor = LogisticRegression(weightCol="weight")
>>> blor.getRegParam()
0.0
>>> blor.setRegParam(0.01)
LogisticRegression...
>>> blor.getRegParam()
0.01
>>> blor.setMaxIter(10)
LogisticRegression...
>>> blor.getMaxIter()
10
>>> blor.clear(blor.maxIter)
>>> blorModel = blor.fit(bdf)
>>> blorModel.setFeaturesCol("features")
LogisticRegressionModel...
>>> blorModel.setProbabilityCol("newProbability")
LogisticRegressionModel...
>>> blorModel.getProbabilityCol()
'newProbability'
>>> blorModel.getMaxBlockSizeInMB()
0.0
>>> blorModel.setThreshold(0.1)
LogisticRegressionModel...
>>> blorModel.getThreshold()
0.1
>>> blorModel.coefficients
DenseVector([-1.080..., -0.646...])
>>> blorModel.intercept
3.112...
>>> blorModel.evaluate(bdf).accuracy == blorModel.summary.accuracy
True
>>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
>>> mdf = spark.read.format("libsvm").load(data_path)
>>> mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family="multinomial")
>>> mlorModel = mlor.fit(mdf)
>>> mlorModel.coefficientMatrix
SparseMatrix(3, 4, [0, 1, 2, 3], [3, 2, 1], [1.87..., -2.75..., -0.50...], 1)
>>> mlorModel.interceptVector
DenseVector([0.04..., -0.42..., 0.37...])
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF()
>>> blorModel.predict(test0.head().features)
1.0
>>> blorModel.predictRaw(test0.head().features)
DenseVector([-3.54..., 3.54...])
>>> blorModel.predictProbability(test0.head().features)
DenseVector([0.028, 0.972])
>>> result = blorModel.transform(test0).head()
>>> result.prediction
1.0
>>> result.newProbability
DenseVector([0.02..., 0.97...])
>>> result.rawPrediction
DenseVector([-3.54..., 3.54...])
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
>>> blorModel.transform(test1).head().prediction
1.0
>>> blor.setParams("vector")
Traceback (most recent call last):
    ...
TypeError: Method setParams forces keyword arguments.
>>> lr_path = temp_path + "/lr"
>>> blor.save(lr_path)
>>> lr2 = LogisticRegression.load(lr_path)
>>> lr2.getRegParam()
0.01
>>> model_path = temp_path + "/lr_model"
>>> blorModel.save(model_path)
>>> model2 = LogisticRegressionModel.load(model_path)
>>> blorModel.coefficients[0] == model2.coefficients[0]
True
>>> blorModel.intercept == model2.intercept
True
>>> model2
LogisticRegressionModel: uid=..., numClasses=2, numFeatures=2
>>> blorModel.transform(test0).take(1) == model2.transform(test0).take(1)
True
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getAggregationDepth()  Gets the value of aggregationDepth or its default value.
getElasticNetParam()  Gets the value of elasticNetParam or its default value.
getFamily()  Gets the value of family or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFitIntercept()  Gets the value of fitIntercept or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getLowerBoundsOnCoefficients()  Gets the value of lowerBoundsOnCoefficients.
getLowerBoundsOnIntercepts()  Gets the value of lowerBoundsOnIntercepts.
getMaxBlockSizeInMB()  Gets the value of maxBlockSizeInMB or its default value.
getMaxIter()  Gets the value of maxIter or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getRegParam()  Gets the value of regParam or its default value.
getStandardization()  Gets the value of standardization or its default value.
getThreshold()  Get threshold for binary classification.
getThresholds()  If thresholds is set, return its value.
getTol()  Gets the value of tol or its default value.
getUpperBoundsOnCoefficients()  Gets the value of upperBoundsOnCoefficients.
getUpperBoundsOnIntercepts()  Gets the value of upperBoundsOnIntercepts.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setAggregationDepth(value)  Sets the value of aggregationDepth.
setElasticNetParam(value)  Sets the value of elasticNetParam.
setFamily(value)  Sets the value of family.
setFeaturesCol(value)  Sets the value of featuresCol.
setFitIntercept(value)  Sets the value of fitIntercept.
setLabelCol(value)  Sets the value of labelCol.
setLowerBoundsOnCoefficients(value)  Sets the value of lowerBoundsOnCoefficients.
setLowerBoundsOnIntercepts(value)  Sets the value of lowerBoundsOnIntercepts.
setMaxBlockSizeInMB(value)  Sets the value of maxBlockSizeInMB.
setMaxIter(value)  Sets the value of maxIter.
setParams(*args, **kwargs)  Sets params for logistic regression (keyword arguments only; see the full signature below).
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setRegParam(value)  Sets the value of regParam.
setStandardization(value)  Sets the value of standardization.
setThreshold(value)  Sets the value of threshold.
setThresholds(value)  Sets the value of thresholds.
setTol(value)  Sets the value of tol.
setUpperBoundsOnCoefficients(value)  Sets the value of upperBoundsOnCoefficients.
setUpperBoundsOnIntercepts(value)  Sets the value of upperBoundsOnIntercepts.
setWeightCol(value)  Sets the value of weightCol.
write()  Returns an MLWriter instance for this ML instance.
Attributes
aggregationDepth
elasticNetParam
family
featuresCol
fitIntercept
labelCol
lowerBoundsOnCoefficients
lowerBoundsOnIntercepts
maxBlockSizeInMB
maxIter
params  Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
regParam
standardization
threshold
thresholds
tol
upperBoundsOnCoefficients
upperBoundsOnIntercepts
weightCol
Methods Documentation
clear(param)Clears a param from the param map if it has been explicitly set.
copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.
explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getAggregationDepth()Gets the value of aggregationDepth or its default value.
getElasticNetParam()Gets the value of elasticNetParam or its default value.
getFamily()Gets the value of family or its default value.
New in version 2.1.0.
getFeaturesCol()Gets the value of featuresCol or its default value.
getFitIntercept()Gets the value of fitIntercept or its default value.
getLabelCol()Gets the value of labelCol or its default value.
getLowerBoundsOnCoefficients()Gets the value of lowerBoundsOnCoefficients
New in version 2.3.0.
getLowerBoundsOnIntercepts()Gets the value of lowerBoundsOnIntercepts
New in version 2.3.0.
getMaxBlockSizeInMB()Gets the value of maxBlockSizeInMB or its default value.
getMaxIter()Gets the value of maxIter or its default value.
getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.
getParam(paramName)Gets a param by its name.
getPredictionCol()Gets the value of predictionCol or its default value.
getProbabilityCol()Gets the value of probabilityCol or its default value.
getRawPredictionCol()Gets the value of rawPredictionCol or its default value.
getRegParam()Gets the value of regParam or its default value.
getStandardization()Gets the value of standardization or its default value.
getThreshold()
Get threshold for binary classification.
If thresholds is set with length 2 (i.e., binary classification), this returns the equivalent threshold: 1 / (1 + thresholds(0) / thresholds(1)). Otherwise, returns threshold if set or its default value if unset.
New in version 1.4.0.
getThresholds()
If thresholds is set, return its value. Otherwise, if threshold is set, return the equivalent thresholds for binary classification: (1-threshold, threshold). If neither is set, throw an error.
New in version 1.5.0.
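The two representations are interchangeable, as this sketch shows (assuming an active SparkSession):

from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(threshold=0.25)
print(lr.getThresholds())    # [0.75, 0.25], i.e. (1 - threshold, threshold)
lr2 = LogisticRegression(thresholds=[0.75, 0.25])
print(lr2.getThreshold())    # 0.25 == 1 / (1 + 0.75 / 0.25)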
getTol()Gets the value of tol or its default value.
getUpperBoundsOnCoefficients()Gets the value of upperBoundsOnCoefficients
New in version 2.3.0.
getUpperBoundsOnIntercepts()Gets the value of upperBoundsOnIntercepts
New in version 2.3.0.
getWeightCol()Gets the value of weightCol or its default value.
hasDefault(param)Checks whether a param has a default value.
hasParam(paramName)Tests whether this instance contains a param with a given (string) name.
isDefined(param)Checks whether a param is explicitly set by user or has a default value.
isSet(param)Checks whether a param is explicitly set by user.
classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()Returns an MLReader instance for this class.
save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)Sets a parameter in the embedded param map.
setAggregationDepth(value)Sets the value of aggregationDepth.
setElasticNetParam(value)Sets the value of elasticNetParam.
setFamily(value)Sets the value of family .
New in version 2.1.0.
setFeaturesCol(value)Sets the value of featuresCol.
New in version 3.0.0.
setFitIntercept(value)Sets the value of fitIntercept.
setLabelCol(value)Sets the value of labelCol.
New in version 3.0.0.
setLowerBoundsOnCoefficients(value)Sets the value of lowerBoundsOnCoefficients
New in version 2.3.0.
setLowerBoundsOnIntercepts(value)Sets the value of lowerBoundsOnIntercepts
New in version 2.3.0.
setMaxBlockSizeInMB(value)Sets the value of maxBlockSizeInMB.
New in version 3.1.0.
setMaxIter(value)Sets the value of maxIter.
setParams(*args, **kwargs)
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction", standardization=True, weightCol=None, aggregationDepth=2, family="auto", lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, maxBlockSizeInMB=0.0): Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent.
New in version 1.3.0.
setPredictionCol(value)Sets the value of predictionCol.
New in version 3.0.0.
setProbabilityCol(value)Sets the value of probabilityCol.
New in version 3.0.0.
setRawPredictionCol(value)Sets the value of rawPredictionCol.
New in version 3.0.0.
setRegParam(value)Sets the value of regParam.
setStandardization(value)Sets the value of standardization.
setThreshold(value)Sets the value of threshold. Clears value of thresholds if it has been set.
New in version 1.4.0.
setThresholds(value)Sets the value of thresholds.
New in version 3.0.0.
setTol(value)Sets the value of tol.
setUpperBoundsOnCoefficients(value)Sets the value of upperBoundsOnCoefficients
New in version 2.3.0.
setUpperBoundsOnIntercepts(value)Sets the value of upperBoundsOnIntercepts
New in version 2.3.0.
setWeightCol(value)Sets the value of weightCol.
write()Returns an MLWriter instance for this ML instance.
Attributes Documentation
aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')
elasticNetParam = Param(parent='undefined', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.')
family = Param(parent='undefined', name='family', doc='The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
lowerBoundsOnCoefficients = Param(parent='undefined', name='lowerBoundsOnCoefficients', doc='The lower bounds on coefficients if fitting under bound constrained optimization. The bound matrix must be compatible with the shape (1, number of features) for binomial regression, or (number of classes, number of features) for multinomial regression.')
lowerBoundsOnIntercepts = Param(parent='undefined', name='lowerBoundsOnIntercepts', doc='The lower bounds on intercepts if fitting under bound constrained optimization. The bounds vector size must be equal to 1 for binomial regression, or the number of classes for multinomial regression.')
maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')
standardization = Param(parent='undefined', name='standardization', doc='whether to standardize the training features before fitting the model.')
threshold = Param(parent='undefined', name='threshold', doc='Threshold in binary classification prediction, in range [0, 1]. If threshold and thresholds are both set, they must match, e.g., if threshold is p, then thresholds must be equal to [1-p, p].')
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
upperBoundsOnCoefficients = Param(parent='undefined', name='upperBoundsOnCoefficients', doc='The upper bounds on coefficients if fitting under bound constrained optimization. The bound matrix must be compatible with the shape (1, number of features) for binomial regression, or (number of classes, number of features) for multinomial regression.')
upperBoundsOnIntercepts = Param(parent='undefined', name='upperBoundsOnIntercepts', doc='The upper bounds on intercepts if fitting under bound constrained optimization. The bound vector size must be equal with 1 for binomial regression, or the number of classes for multinomial regression.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
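To illustrate the bound-constraint params above, a hedged sketch of fitting a binomial model with non-negative coefficients (df is an assumed training DataFrame with two features; for binomial regression with n features, the bound matrix has shape (1, n)):

>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.linalg import Matrices, Vectors
>>> blr = LogisticRegression(
...     family="binomial",
...     lowerBoundsOnCoefficients=Matrices.dense(1, 2, [0.0, 0.0]),
...     lowerBoundsOnIntercepts=Vectors.dense([0.0]))
>>> model = blr.fit(df)  # coefficients and intercept constrained to be >= 0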
LogisticRegressionModel
class pyspark.ml.classification.LogisticRegressionModel(java_model=None)Model fitted by LogisticRegression.
New in version 1.3.0.
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset)  Evaluates the model on a test dataset.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getAggregationDepth()  Gets the value of aggregationDepth or its default value.
getElasticNetParam()  Gets the value of elasticNetParam or its default value.
getFamily()  Gets the value of family or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFitIntercept()  Gets the value of fitIntercept or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getLowerBoundsOnCoefficients()  Gets the value of lowerBoundsOnCoefficients.
getLowerBoundsOnIntercepts()  Gets the value of lowerBoundsOnIntercepts.
getMaxBlockSizeInMB()  Gets the value of maxBlockSizeInMB or its default value.
getMaxIter()  Gets the value of maxIter or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getRegParam()  Gets the value of regParam or its default value.
getStandardization()  Gets the value of standardization or its default value.
getThreshold()  Get threshold for binary classification.
getThresholds()  If thresholds is set, return its value.
getTol()  Gets the value of tol or its default value.
getUpperBoundsOnCoefficients()  Gets the value of upperBoundsOnCoefficients.
getUpperBoundsOnIntercepts()  Gets the value of upperBoundsOnIntercepts.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)  Predict label for the given features.
predictProbability(value)  Predict the probability of each class given the features.
predictRaw(value)  Raw prediction for each possible label.
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setThreshold(value)  Sets the value of threshold.
setThresholds(value)  Sets the value of thresholds.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
aggregationDepth
coefficientMatrix  Model coefficients.
coefficients  Model coefficients of binomial logistic regression.
elasticNetParam
family
featuresCol
fitIntercept
hasSummary  Indicates whether a training summary exists for this model instance.
intercept  Model intercept of binomial logistic regression.
interceptVector  Model intercept.
labelCol
lowerBoundsOnCoefficients
lowerBoundsOnIntercepts
maxBlockSizeInMB
maxIter
numClasses  Number of classes (values which the label can take).
numFeatures  Returns the number of features the model was trained on.
params  Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
regParam
standardization
summary  Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on the training set.
threshold
thresholds
tol
upperBoundsOnCoefficients
upperBoundsOnIntercepts
weightCol
Methods Documentation
clear(param)Clears a param from the param map if it has been explicitly set.
copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
evaluate(dataset)Evaluates the model on a test dataset.
New in version 2.0.0.
Parameters
dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on.
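For example (a sketch; test_df is an assumed DataFrame with the same schema the model was trained on):

>>> summary = model.evaluate(test_df)  # returns a LogisticRegressionSummary
>>> print(summary.accuracy, summary.weightedFMeasure())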
explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.
explainParams()Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
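A small sketch of the merge ordering (model is an assumed fitted LogisticRegressionModel; the extra value 0.2 is illustrative):

>>> pmap = model.extractParamMap({model.regParam: 0.2})
>>> pmap[model.regParam]  # extra overrides both default and user-supplied values
0.2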
getAggregationDepth()Gets the value of aggregationDepth or its default value.
getElasticNetParam()Gets the value of elasticNetParam or its default value.
getFamily()Gets the value of family or its default value.
New in version 2.1.0.
getFeaturesCol()Gets the value of featuresCol or its default value.
getFitIntercept()Gets the value of fitIntercept or its default value.
getLabelCol()Gets the value of labelCol or its default value.
getLowerBoundsOnCoefficients()Gets the value of lowerBoundsOnCoefficients
New in version 2.3.0.
getLowerBoundsOnIntercepts()Gets the value of lowerBoundsOnIntercepts
New in version 2.3.0.
getMaxBlockSizeInMB()Gets the value of maxBlockSizeInMB or its default value.
getMaxIter()Gets the value of maxIter or its default value.
getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.
getParam(paramName)Gets a param by its name.
getPredictionCol()Gets the value of predictionCol or its default value.
getProbabilityCol()Gets the value of probabilityCol or its default value.
getRawPredictionCol()Gets the value of rawPredictionCol or its default value.
getRegParam()Gets the value of regParam or its default value.
getStandardization()Gets the value of standardization or its default value.
getThreshold()Get threshold for binary classification.
If thresholds is set with length 2 (i.e., binary classification), this returns the equivalent threshold: 1 / (1 + thresholds(0) / thresholds(1)). Otherwise, returns threshold if set or its default value if unset.
New in version 1.4.0.
getThresholds()If thresholds is set, return its value. Otherwise, if threshold is set, return the equivalent thresholdsfor binary classification: (1-threshold, threshold). If neither are set, throw an error.
New in version 1.5.0.
getTol()Gets the value of tol or its default value.
getUpperBoundsOnCoefficients()Gets the value of upperBoundsOnCoefficients
New in version 2.3.0.
getUpperBoundsOnIntercepts()Gets the value of upperBoundsOnIntercepts
New in version 2.3.0.
getWeightCol()Gets the value of weightCol or its default value.
hasDefault(param)Checks whether a param has a default value.
hasParam(paramName)Tests whether this instance contains a param with a given (string) name.
isDefined(param)Checks whether a param is explicitly set by user or has a default value.
isSet(param)Checks whether a param is explicitly set by user.
classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)Predict label for the given features.
New in version 3.0.0.
predictProbability(value)Predict the probability of each class given the features.
New in version 3.0.0.
predictRaw(value)Raw prediction for each possible label.
New in version 3.0.0.
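These three methods operate on a single feature vector rather than a DataFrame; a sketch (model is an assumed fitted binary LogisticRegressionModel trained on two features):

>>> from pyspark.ml.linalg import Vectors
>>> v = Vectors.dense([1.0, 0.5])
>>> label = model.predict(v)             # predicted label, e.g. 0.0 or 1.0
>>> margins = model.predictRaw(v)        # raw margins for each class
>>> probs = model.predictProbability(v)  # class-conditional probabilities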
classmethod read()Returns an MLReader instance for this class.
save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)Sets a parameter in the embedded param map.
setFeaturesCol(value)Sets the value of featuresCol.
New in version 3.0.0.
setPredictionCol(value)Sets the value of predictionCol.
New in version 3.0.0.
setProbabilityCol(value)Sets the value of probabilityCol.
New in version 3.0.0.
setRawPredictionCol(value)Sets the value of rawPredictionCol.
New in version 3.0.0.
setThreshold(value)Sets the value of threshold. Clears value of thresholds if it has been set.
New in version 1.4.0.
setThresholds(value)Sets the value of thresholds.
New in version 3.0.0.
transform(dataset, params=None)Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
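As a sketch, the param map can override embedded params for a single call (test_df is an assumed DataFrame with a features column):

>>> out = model.transform(test_df, {model.thresholds: [0.3, 0.7]})
>>> out.select("features", "rawPrediction", "probability", "prediction").show()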
write()Returns an MLWriter instance for this ML instance.
Attributes Documentation
aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')
coefficientMatrixModel coefficients.
New in version 2.1.0.
coefficientsModel coefficients of binomial logistic regression. An exception is thrown in the case of multinomiallogistic regression.
New in version 2.0.0.
elasticNetParam = Param(parent='undefined', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.')
family = Param(parent='undefined', name='family', doc='The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
hasSummaryIndicates whether a training summary exists for this model instance.
New in version 2.1.0.
interceptModel intercept of binomial logistic regression. An exception is thrown in the case of multinomial logisticregression.
New in version 1.4.0.
interceptVectorModel intercept.
New in version 2.1.0.
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
lowerBoundsOnCoefficients = Param(parent='undefined', name='lowerBoundsOnCoefficients', doc='The lower bounds on coefficients if fitting under bound constrained optimization. The bound matrix must be compatible with the shape (1, number of features) for binomial regression, or (number of classes, number of features) for multinomial regression.')
lowerBoundsOnIntercepts = Param(parent='undefined', name='lowerBoundsOnIntercepts', doc='The lower bounds on intercepts if fitting under bound constrained optimization. The bounds vector size must be equal to 1 for binomial regression, or the number of classes for multinomial regression.')
maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
numClassesNumber of classes (values which the label can take).
New in version 2.1.0.
numFeaturesReturns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')
standardization = Param(parent='undefined', name='standardization', doc='whether to standardize the training features before fitting the model.')
summaryGets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on thetraining set. An exception is thrown if trainingSummary is None.
New in version 2.0.0.
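For instance, right after fitting (a sketch; train_df is an assumed training DataFrame):

>>> model = LogisticRegression(maxIter=10).fit(train_df)
>>> if model.hasSummary:
...     s = model.summary  # a LogisticRegressionTrainingSummary
...     print(s.accuracy)
...     print(s.objectiveHistory)  # loss per iteration, plus the initial state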
threshold = Param(parent='undefined', name='threshold', doc='Threshold in binary classification prediction, in range [0, 1]. If threshold and thresholds are both set, they must match, e.g., if threshold is p, then thresholds must be equal to [1-p, p].')
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
upperBoundsOnCoefficients = Param(parent='undefined', name='upperBoundsOnCoefficients', doc='The upper bounds on coefficients if fitting under bound constrained optimization. The bound matrix must be compatible with the shape (1, number of features) for binomial regression, or (number of classes, number of features) for multinomial regression.')
upperBoundsOnIntercepts = Param(parent='undefined', name='upperBoundsOnIntercepts', doc='The upper bounds on intercepts if fitting under bound constrained optimization. The bound vector size must be equal with 1 for binomial regression, or the number of classes for multinomial regression.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
LogisticRegressionSummary
class pyspark.ml.classification.LogisticRegressionSummary(java_obj=None)Abstraction for Logistic Regression Results for a given model.
New in version 2.0.0.
Methods
fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.
Attributes
accuracy  Returns accuracy.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
featuresCol  Field in “predictions” which gives the features of each instance as a vector.
labelCol  Field in “predictions” which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
precisionByLabel  Returns precision for each label (category).
predictionCol  Field in “predictions” which gives the prediction of each class.
predictions  DataFrame output by the model’s transform method.
probabilityCol  Field in “predictions” which gives the probability of each class as a vector.
recallByLabel  Returns recall for each label (category).
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in “predictions” which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).
New in version 3.1.0.
weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.
New in version 3.1.0.
Attributes Documentation
accuracyReturns accuracy (the fraction of correctly classified instances out of the total number of instances).
New in version 3.1.0.
falsePositiveRateByLabelReturns false positive rate for each label (category).
New in version 3.1.0.
featuresColField in “predictions” which gives the features of each instance as a vector.
New in version 2.0.0.
labelColField in “predictions” which gives the true label of each instance.
New in version 3.1.0.
labelsReturns the sequence of labels in ascending order. This order matches the order used in metrics which arespecified as arrays over labels, e.g., truePositiveRateByLabel.
New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.
precisionByLabelReturns precision for each label (category).
New in version 3.1.0.
predictionColField in “predictions” which gives the prediction of each class.
New in version 3.1.0.
predictionsDataFrame output by the model’s transform method.
New in version 3.1.0.
probabilityColField in “predictions” which gives the probability of each class as a vector.
New in version 2.0.0.
recallByLabelReturns recall for each label (category).
New in version 3.1.0.
truePositiveRateByLabelReturns true positive rate for each label (category).
New in version 3.1.0.
weightColField in “predictions” which gives the weight of each instance as a vector.
New in version 3.1.0.
weightedFalsePositiveRateReturns weighted false positive rate.
New in version 3.1.0.
weightedPrecisionReturns weighted averaged precision.
New in version 3.1.0.
weightedRecallReturns weighted averaged recall. (equals to precision, recall and f-measure)
New in version 3.1.0.
weightedTruePositiveRateReturns weighted true positive rate. (equals to precision, recall and f-measure)
New in version 3.1.0.
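A sketch tying the per-label metrics to labels (summary is an assumed LogisticRegressionSummary):

>>> for label, tpr, prec in zip(summary.labels,
...                             summary.truePositiveRateByLabel,
...                             summary.precisionByLabel):
...     print(label, tpr, prec)
>>> summary.fMeasureByLabel(beta=1.0)  # list of per-label F-measures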
LogisticRegressionTrainingSummary
class pyspark.ml.classification.LogisticRegressionTrainingSummary(java_obj=None)Abstraction for multinomial Logistic Regression Training results.
New in version 2.0.0.
Methods
fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.
Attributes
accuracy  Returns accuracy.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
featuresCol  Field in “predictions” which gives the features of each instance as a vector.
labelCol  Field in “predictions” which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
objectiveHistory  Objective function (scaled loss + regularization) at each iteration.
precisionByLabel  Returns precision for each label (category).
predictionCol  Field in “predictions” which gives the prediction of each class.
predictions  DataFrame output by the model’s transform method.
probabilityCol  Field in “predictions” which gives the probability of each class as a vector.
recallByLabel  Returns recall for each label (category).
totalIterations  Number of training iterations until termination.
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in “predictions” which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).
New in version 3.1.0.
weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.
New in version 3.1.0.
Attributes Documentation
accuracyReturns accuracy (the fraction of correctly classified instances out of the total number of instances).
New in version 3.1.0.
falsePositiveRateByLabelReturns false positive rate for each label (category).
New in version 3.1.0.
featuresColField in “predictions” which gives the features of each instance as a vector.
New in version 2.0.0.
labelColField in “predictions” which gives the true label of each instance.
New in version 3.1.0.
labelsReturns the sequence of labels in ascending order. This order matches the order used in metrics which arespecified as arrays over labels, e.g., truePositiveRateByLabel.
New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.
objectiveHistoryObjective function (scaled loss + regularization) at each iteration. It contains one more element, the initial state, than the number of iterations.
New in version 3.1.0.
precisionByLabelReturns precision for each label (category).
New in version 3.1.0.
predictionColField in “predictions” which gives the prediction of each class.
New in version 3.1.0.
predictionsDataFrame output by the model’s transform method.
New in version 3.1.0.
probabilityColField in “predictions” which gives the probability of each class as a vector.
New in version 2.0.0.
recallByLabelReturns recall for each label (category).
New in version 3.1.0.
totalIterationsNumber of training iterations until termination.
New in version 3.1.0.
truePositiveRateByLabelReturns true positive rate for each label (category).
New in version 3.1.0.
weightColField in “predictions” which gives the weight of each instance as a vector.
New in version 3.1.0.
weightedFalsePositiveRateReturns weighted false positive rate.
New in version 3.1.0.
weightedPrecisionReturns weighted averaged precision.
New in version 3.1.0.
weightedRecallReturns weighted averaged recall. (equals to precision, recall and f-measure)
New in version 3.1.0.
weightedTruePositiveRateReturns weighted true positive rate. (equals to precision, recall and f-measure)
New in version 3.1.0.
BinaryLogisticRegressionSummary
class pyspark.ml.classification.BinaryLogisticRegressionSummary(java_obj=None)Binary Logistic regression results for a given model.
New in version 2.0.0.
Methods
fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.
Attributes
accuracy  Returns accuracy.
areaUnderROC  Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold  Returns a DataFrame with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
featuresCol  Field in “predictions” which gives the features of each instance as a vector.
labelCol  Field in “predictions” which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
pr  Returns the precision-recall curve, which is a DataFrame containing two fields (recall, precision), with (0.0, 1.0) prepended to it.
precisionByLabel  Returns precision for each label (category).
precisionByThreshold  Returns a DataFrame with two fields (threshold, precision) curve.
predictionCol  Field in “predictions” which gives the prediction of each class.
predictions  DataFrame output by the model’s transform method.
probabilityCol  Field in “predictions” which gives the probability of each class as a vector.
recallByLabel  Returns recall for each label (category).
recallByThreshold  Returns a DataFrame with two fields (threshold, recall) curve.
roc  Returns the receiver operating characteristic (ROC) curve, which is a DataFrame having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol  Field in “predictions” which gives the probability or raw prediction of each class as a vector.
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in “predictions” which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).
New in version 3.1.0.
weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.
New in version 3.1.0.
Attributes Documentation
accuracyReturns accuracy (the fraction of correctly classified instances out of the total number of instances).
New in version 3.1.0.
areaUnderROCComputes the area under the receiver operating characteristic (ROC) curve.
New in version 3.1.0.
fMeasureByThresholdReturns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
New in version 3.1.0.
falsePositiveRateByLabelReturns false positive rate for each label (category).
New in version 3.1.0.
featuresColField in “predictions” which gives the features of each instance as a vector.
New in version 2.0.0.
labelColField in “predictions” which gives the true label of each instance.
New in version 3.1.0.
labelsReturns the sequence of labels in ascending order. This order matches the order used in metrics which arespecified as arrays over labels, e.g., truePositiveRateByLabel.
New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.
prReturns the precision-recall curve, which is a DataFrame containing two fields (recall, precision), with (0.0, 1.0) prepended to it.
New in version 3.1.0.
precisionByLabelReturns precision for each label (category).
New in version 3.1.0.
precisionByThresholdReturns a DataFrame with two fields, (threshold, precision), forming a curve. Every distinct probability obtained in transforming the dataset is used as a threshold when calculating the precision.
New in version 3.1.0.
predictionColField in “predictions” which gives the prediction of each class.
New in version 3.1.0.
predictionsDataFrame output by the model’s transform method.
New in version 3.1.0.
probabilityColField in “predictions” which gives the probability of each class as a vector.
New in version 2.0.0.
recallByLabelReturns recall for each label (category).
New in version 3.1.0.
recallByThresholdReturns a DataFrame with two fields, (threshold, recall), forming a curve. Every distinct probability obtained in transforming the dataset is used as a threshold when calculating the recall.
New in version 3.1.0.
rocReturns the receiver operating characteristic (ROC) curve, which is a DataFrame having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
New in version 3.1.0.
Notes
Wikipedia reference
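Since roc and pr are plain DataFrames, their points can be collected, e.g. for plotting (a sketch; bin_summary is an assumed BinaryLogisticRegressionSummary):

>>> print(bin_summary.areaUnderROC)
>>> roc_points = [(r["FPR"], r["TPR"]) for r in bin_summary.roc.collect()]
>>> pr_points = [(r["recall"], r["precision"]) for r in bin_summary.pr.collect()]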
scoreColField in “predictions” which gives the probability or raw prediction of each class as a vector.
New in version 3.1.0.
truePositiveRateByLabelReturns true positive rate for each label (category).
New in version 3.1.0.
weightColField in “predictions” which gives the weight of each instance as a vector.
New in version 3.1.0.
weightedFalsePositiveRateReturns weighted false positive rate.
New in version 3.1.0.
weightedPrecisionReturns weighted averaged precision.
New in version 3.1.0.
weightedRecallReturns weighted averaged recall. (equals to precision, recall and f-measure)
New in version 3.1.0.
weightedTruePositiveRateReturns weighted true positive rate. (equals to precision, recall and f-measure)
New in version 3.1.0.
BinaryLogisticRegressionTrainingSummary
class pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary(java_obj=None)Binary Logistic regression training results for a given model.
New in version 2.0.0.
Methods
fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.
Attributes
accuracy  Returns accuracy.
areaUnderROC  Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold  Returns a DataFrame with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
featuresCol  Field in “predictions” which gives the features of each instance as a vector.
labelCol  Field in “predictions” which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
objectiveHistory  Objective function (scaled loss + regularization) at each iteration.
pr  Returns the precision-recall curve, which is a DataFrame containing two fields (recall, precision), with (0.0, 1.0) prepended to it.
precisionByLabel  Returns precision for each label (category).
precisionByThreshold  Returns a DataFrame with two fields (threshold, precision) curve.
predictionCol  Field in “predictions” which gives the prediction of each class.
predictions  DataFrame output by the model’s transform method.
probabilityCol  Field in “predictions” which gives the probability of each class as a vector.
recallByLabel  Returns recall for each label (category).
recallByThreshold  Returns a DataFrame with two fields (threshold, recall) curve.
roc  Returns the receiver operating characteristic (ROC) curve, which is a DataFrame having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol  Field in “predictions” which gives the probability or raw prediction of each class as a vector.
totalIterations  Number of training iterations until termination.
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in “predictions” which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).
New in version 3.1.0.
weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.
New in version 3.1.0.
Attributes Documentation
accuracyReturns accuracy (the fraction of correctly classified instances out of the total number of instances).
New in version 3.1.0.
areaUnderROCComputes the area under the receiver operating characteristic (ROC) curve.
New in version 3.1.0.
fMeasureByThresholdReturns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
New in version 3.1.0.
falsePositiveRateByLabelReturns false positive rate for each label (category).
New in version 3.1.0.
featuresColField in “predictions” which gives the features of each instance as a vector.
New in version 2.0.0.
labelColField in “predictions” which gives the true label of each instance.
New in version 3.1.0.
labelsReturns the sequence of labels in ascending order. This order matches the order used in metrics which arespecified as arrays over labels, e.g., truePositiveRateByLabel.
New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.
objectiveHistoryObjective function (scaled loss + regularization) at each iteration. It contains one more element, the initial state, than the number of iterations.
New in version 3.1.0.
prReturns the precision-recall curve, which is a DataFrame containing two fields (recall, precision), with (0.0, 1.0) prepended to it.
New in version 3.1.0.
precisionByLabelReturns precision for each label (category).
New in version 3.1.0.
precisionByThresholdReturns a DataFrame with two fields, (threshold, precision), forming a curve. Every distinct probability obtained in transforming the dataset is used as a threshold when calculating the precision.
New in version 3.1.0.
predictionColField in “predictions” which gives the prediction of each class.
New in version 3.1.0.
predictionsDataFrame output by the model’s transform method.
New in version 3.1.0.
probabilityColField in “predictions” which gives the probability of each class as a vector.
New in version 2.0.0.
recallByLabelReturns recall for each label (category).
New in version 3.1.0.
recallByThresholdReturns a DataFrame with two fields, (threshold, recall), forming a curve. Every distinct probability obtained in transforming the dataset is used as a threshold when calculating the recall.
New in version 3.1.0.
rocReturns the receiver operating characteristic (ROC) curve, which is a DataFrame having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
New in version 3.1.0.
Notes
Wikipedia reference
scoreColField in “predictions” which gives the probability or raw prediction of each class as a vector.
New in version 3.1.0.
totalIterationsNumber of training iterations until termination.
New in version 3.1.0.
truePositiveRateByLabelReturns true positive rate for each label (category).
New in version 3.1.0.
weightColField in “predictions” which gives the weight of each instance as a vector.
New in version 3.1.0.
weightedFalsePositiveRateReturns weighted false positive rate.
New in version 3.1.0.
weightedPrecisionReturns weighted averaged precision.
New in version 3.1.0.
weightedRecallReturns weighted averaged recall. (equals to precision, recall and f-measure)
New in version 3.1.0.
weightedTruePositiveRateReturns weighted true positive rate. (equals to precision, recall and f-measure)
New in version 3.1.0.
DecisionTreeClassifier
class pyspark.ml.classification.DecisionTreeClassifier(*args, **kwargs)Decision tree learning algorithm for classification. It supports both binary and multiclass labels, as well as bothcontinuous and categorical features.
New in version 1.4.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed", leafCol="leafId")
>>> model = dt.fit(td)
>>> model.getLabelCol()
'indexed'
>>> model.setFeaturesCol("features")
DecisionTreeClassificationModel...
>>> model.numNodes
3
>>> model.depth
1
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> model.numClasses
2
>>> print(model.toDebugString)
DecisionTreeClassificationModel...depth=1, numNodes=3...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictRaw(test0.head().features)
DenseVector([1.0, 0.0])
>>> model.predictProbability(test0.head().features)
DenseVector([1.0, 0.0])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.probability
DenseVector([1.0, 0.0])
>>> result.rawPrediction
DenseVector([1.0, 0.0])
>>> result.leafId
0.0
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> dtc_path = temp_path + "/dtc"
>>> dt.save(dtc_path)
>>> dt2 = DecisionTreeClassifier.load(dtc_path)
>>> dt2.getMaxDepth()
2
>>> model_path = temp_path + "/dtc_model"
>>> model.save(model_path)
>>> model2 = DecisionTreeClassificationModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> df3 = spark.createDataFrame([
...     (1.0, 0.2, Vectors.dense(1.0)),
...     (1.0, 0.8, Vectors.dense(1.0)),
...     (0.0, 1.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
>>> si3 = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model3 = si3.fit(df3)
>>> td3 = si_model3.transform(df3)
>>> dt3 = DecisionTreeClassifier(maxDepth=2, weightCol="weight", labelCol="indexed")
>>> model3 = dt3.fit(td3)
>>> print(model3.toDebugString)
DecisionTreeClassificationModel...depth=1, numNodes=3...
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getCacheNodeIds()  Gets the value of cacheNodeIds or its default value.
getCheckpointInterval()  Gets the value of checkpointInterval or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getImpurity()  Gets the value of impurity or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getLeafCol()  Gets the value of leafCol or its default value.
getMaxBins()  Gets the value of maxBins or its default value.
getMaxDepth()  Gets the value of maxDepth or its default value.
getMaxMemoryInMB()  Gets the value of maxMemoryInMB or its default value.
getMinInfoGain()  Gets the value of minInfoGain or its default value.
getMinInstancesPerNode()  Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode()  Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getSeed()  Gets the value of seed or its default value.
getThresholds()  Gets the value of thresholds or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)  Sets a parameter in the embedded param map.
setCacheNodeIds(value)  Sets the value of cacheNodeIds.
setCheckpointInterval(value)  Sets the value of checkpointInterval.
setFeaturesCol(value)  Sets the value of featuresCol.
setImpurity(value)  Sets the value of impurity.
setLabelCol(value)  Sets the value of labelCol.
setLeafCol(value)  Sets the value of leafCol.
setMaxBins(value)  Sets the value of maxBins.
setMaxDepth(value)  Sets the value of maxDepth.
setMaxMemoryInMB(value)  Sets the value of maxMemoryInMB.
setMinInfoGain(value)  Sets the value of minInfoGain.
setMinInstancesPerNode(value)  Sets the value of minInstancesPerNode.
setMinWeightFractionPerNode(value)  Sets the value of minWeightFractionPerNode.
setParams(self, *[, featuresCol, labelCol, ...])  Sets params for the DecisionTreeClassifier.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setSeed(value)  Sets the value of seed.
setThresholds(value)  Sets the value of thresholds.
setWeightCol(value)  Sets the value of weightCol.
write()  Returns an MLWriter instance for this ML instance.
Attributes
cacheNodeIds
checkpointInterval
featuresCol
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
params  Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
supportedImpurities
thresholds
weightCol
Methods Documentation
clear(param)Clears a param from the param map if it has been explicitly set.
copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.
explainParams()Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
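A sketch of consuming the iterator directly (dt is an assumed DecisionTreeClassifier, train_df an assumed training DataFrame):

>>> param_maps = [{dt.maxDepth: 2}, {dt.maxDepth: 5}]
>>> models = [None] * len(param_maps)
>>> for index, model in dt.fitMultiple(train_df, param_maps):
...     models[index] = model  # index says which param map produced the model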
getCacheNodeIds()Gets the value of cacheNodeIds or its default value.
getCheckpointInterval()Gets the value of checkpointInterval or its default value.
getFeaturesCol()Gets the value of featuresCol or its default value.
getImpurity()Gets the value of impurity or its default value.
New in version 1.6.0.
getLabelCol()Gets the value of labelCol or its default value.
getLeafCol()Gets the value of leafCol or its default value.
getMaxBins()Gets the value of maxBins or its default value.
getMaxDepth()Gets the value of maxDepth or its default value.
getMaxMemoryInMB()Gets the value of maxMemoryInMB or its default value.
getMinInfoGain()Gets the value of minInfoGain or its default value.
getMinInstancesPerNode()Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode()Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.
getParam(paramName)Gets a param by its name.
getPredictionCol()Gets the value of predictionCol or its default value.
getProbabilityCol()Gets the value of probabilityCol or its default value.
getRawPredictionCol()Gets the value of rawPredictionCol or its default value.
getSeed()Gets the value of seed or its default value.
getThresholds()Gets the value of thresholds or its default value.
getWeightCol()Gets the value of weightCol or its default value.
hasDefault(param)Checks whether a param has a default value.
hasParam(paramName)Tests whether this instance contains a param with a given (string) name.
isDefined(param)Checks whether a param is explicitly set by user or has a default value.
isSet(param)Checks whether a param is explicitly set by user.
classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()Returns an MLReader instance for this class.
save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)Sets a parameter in the embedded param map.
setCacheNodeIds(value)Sets the value of cacheNodeIds.
setCheckpointInterval(value)Sets the value of checkpointInterval.
New in version 1.4.0.
setFeaturesCol(value)Sets the value of featuresCol.
New in version 3.0.0.
setImpurity(value)Sets the value of impurity .
New in version 1.4.0.
setLabelCol(value)Sets the value of labelCol.
New in version 3.0.0.
setLeafCol(value)Sets the value of leafCol.
setMaxBins(value)Sets the value of maxBins.
setMaxDepth(value)Sets the value of maxDepth.
setMaxMemoryInMB(value)Sets the value of maxMemoryInMB.
setMinInfoGain(value)Sets the value of minInfoGain.
setMinInstancesPerNode(value)Sets the value of minInstancesPerNode.
setMinWeightFractionPerNode(value)Sets the value of minWeightFractionPerNode.
New in version 3.0.0.
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
Sets params for the DecisionTreeClassifier.
New in version 1.4.0.
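A usage sketch (train_df is an assumed DataFrame with indexed label and features columns):

>>> dt = DecisionTreeClassifier()
>>> dt = dt.setParams(maxDepth=3, impurity="entropy", minInstancesPerNode=2, seed=42)
>>> model = dt.fit(train_df)
>>> print(model.toDebugString)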
setPredictionCol(value)Sets the value of predictionCol.
New in version 3.0.0.
setProbabilityCol(value)Sets the value of probabilityCol.
New in version 3.0.0.
setRawPredictionCol(value)Sets the value of rawPredictionCol.
New in version 3.0.0.
setSeed(value)Sets the value of seed.
setThresholds(value)Sets the value of thresholds.
New in version 3.0.0.
setWeightCol(value)Sets the value of weightCol.
New in version 3.0.0.
write()Returns an MLWriter instance for this ML instance.
Attributes Documentation
cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: entropy, gini')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')
maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')
maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')
maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')
minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')
minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')
minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')
paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
supportedImpurities = ['entropy', 'gini']
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
DecisionTreeClassificationModel
class pyspark.ml.classification.DecisionTreeClassificationModel(java_model=None)
Model fitted by DecisionTreeClassifier.
New in version 1.4.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCacheNodeIds() Gets the value of cacheNodeIds or its default value.
getCheckpointInterval() Gets the value of checkpointInterval or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getImpurity() Gets the value of impurity or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getLeafCol() Gets the value of leafCol or its default value.
getMaxBins() Gets the value of maxBins or its default value.
getMaxDepth() Gets the value of maxDepth or its default value.
getMaxMemoryInMB() Gets the value of maxMemoryInMB or its default value.
getMinInfoGain() Gets the value of minInfoGain or its default value.
getMinInstancesPerNode() Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode() Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getProbabilityCol() Gets the value of probabilityCol or its default value.
getRawPredictionCol() Gets the value of rawPredictionCol or its default value.
getSeed() Gets the value of seed or its default value.
getThresholds() Gets the value of thresholds or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) Predict label for the given features.
predictLeaf(value) Predict the indices of the leaves corresponding to the feature vector.
predictProbability(value) Predict the probability of each class given the features.
predictRaw(value) Raw prediction for each possible label.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setLeafCol(value) Sets the value of leafCol.
setPredictionCol(value) Sets the value of predictionCol.
setProbabilityCol(value) Sets the value of probabilityCol.
setRawPredictionCol(value) Sets the value of rawPredictionCol.
setThresholds(value) Sets the value of thresholds.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
cacheNodeIds
checkpointInterval
depth Return depth of the decision tree.
featureImportances Estimate of the importance of each feature.
featuresCol
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numClasses Number of classes (values which the label can take).
numFeatures Returns the number of features the model was trained on.
numNodes Return number of nodes of the decision tree.
params Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
supportedImpurities
thresholds
toDebugString Full description of model.
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
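A quick illustration of the merge ordering, sketched on the companion DecisionTreeClassifier estimator (since model instances are only obtained by fitting); the values are arbitrary:

>>> from pyspark.ml.classification import DecisionTreeClassifier
>>> dt = DecisionTreeClassifier(maxDepth=3)      # user-supplied value
>>> pm = dt.extractParamMap({dt.maxDepth: 7})    # `extra` wins on conflict
>>> pm[dt.maxDepth]
7
>>> pm[dt.maxBins]                               # untouched default
32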
getCacheNodeIds()
Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getImpurity()
Gets the value of impurity or its default value.

New in version 1.6.0.
getLabelCol()
Gets the value of labelCol or its default value.

getLeafCol()
Gets the value of leafCol or its default value.

getMaxBins()
Gets the value of maxBins or its default value.

getMaxDepth()
Gets the value of maxDepth or its default value.

getMaxMemoryInMB()
Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getRawPredictionCol()
Gets the value of rawPredictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getThresholds()
Gets the value of thresholds or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)
Predict label for the given features.

New in version 3.0.0.

predictLeaf(value)
Predict the indices of the leaves corresponding to the feature vector.

New in version 3.0.0.

predictProbability(value)
Predict the probability of each class given the features.

New in version 3.0.0.

predictRaw(value)
Raw prediction for each possible label.

New in version 3.0.0.

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setLeafCol(value)
Sets the value of leafCol.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)
Sets the value of rawPredictionCol.

New in version 3.0.0.

setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
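Embedded params can be overridden per call via the params map. A minimal sketch, assuming model is a fitted DecisionTreeClassificationModel and spark is an active SparkSession:

>>> from pyspark.ml.linalg import Vectors
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> out = model.transform(test0, {model.predictionCol: "pred"})  # rename output column for this call only
>>> "pred" in out.columns
True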
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
depth
Return depth of the decision tree.
New in version 1.5.0.
featureImportances
Estimate of the importance of each feature.

This generalizes the idea of "Gini" importance to other losses, following the explanation of Gini importance from the "Random Forests" documentation by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.

This feature importance is calculated as follows:

• importance(feature j) = sum (over nodes which split on feature j) of the gain, where gain is scaled by the number of instances passing through node

• Normalize importances for tree to sum to 1.
New in version 2.0.0.
Notes
Feature importance for single decision trees can have high variance due to correlated predictor variables. Consider using a RandomForestClassifier to determine feature importance instead.
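A common way to read the importances is to pair them with the feature names used when assembling the features vector; a sketch assuming model is a fitted model and using placeholder names:

>>> imps = model.featureImportances           # vector of length numFeatures
>>> names = ["f0", "f1", "f2"]                # hypothetical names, in assembly order
>>> sorted(zip(names, imps.toArray()), key=lambda kv: -kv[1])  # most important first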
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: entropy, gini')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')
maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')
maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')
maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')
minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')
minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')
minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')
numClasses
Number of classes (values which the label can take).
New in version 2.1.0.
numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
numNodes
Return number of nodes of the decision tree.
New in version 1.5.0.
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
supportedImpurities = ['entropy', 'gini']
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
toDebugString
Full description of model.
New in version 2.0.0.
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
GBTClassifier
class pyspark.ml.classification.GBTClassifier(*args, **kwargs)
Gradient-Boosted Trees (GBTs) learning algorithm for classification. It supports binary labels, as well as both continuous and categorical features.
New in version 1.4.0.
Notes
Multiclass labels are not currently supported.
The implementation is based upon: J.H. Friedman. “Stochastic Gradient Boosting.” 1999.
Gradient Boosting vs. TreeBoost:
• This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
• Both algorithms learn tree ensembles by minimizing loss functions.
• TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes based on the loss function, whereas the original gradient boosting method does not.
• We expect to implement TreeBoost in the future: SPARK-4240
Examples
>>> from numpy import allclose
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42,
...     leafCol="leafId")
>>> gbt.setMaxIter(5)
GBTClassifier...
>>> gbt.setMinWeightFractionPerNode(0.049)
GBTClassifier...
>>> gbt.getMaxIter()
5
>>> gbt.getFeatureSubsetStrategy()
'all'
>>> model = gbt.fit(td)
>>> model.getLabelCol()
'indexed'
>>> model.setFeaturesCol("features")
GBTClassificationModel...
>>> model.setThresholds([0.3, 0.7])
GBTClassificationModel...
>>> model.getThresholds()
[0.3, 0.7]
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictRaw(test0.head().features)
DenseVector([1.1697, -1.1697])
>>> model.predictProbability(test0.head().features)
DenseVector([0.9121, 0.0879])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.leafId
DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> model.totalNumNodes
15
>>> print(model.toDebugString)
GBTClassificationModel...numTrees=5...
>>> gbtc_path = temp_path + "gbtc"
>>> gbt.save(gbtc_path)
>>> gbt2 = GBTClassifier.load(gbtc_path)
>>> gbt2.getMaxDepth()
2
>>> model_path = temp_path + "gbtc_model"
>>> model.save(model_path)
>>> model2 = GBTClassificationModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.treeWeights == model2.treeWeights
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> model.trees
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
...     ["indexed", "features"])
>>> model.evaluateEachIteration(validation)
[0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
>>> model.numClasses
2
>>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
>>> gbt.getValidationIndicatorCol()
'validationIndicator'
>>> gbt.getValidationTol()
0.01
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getCacheNodeIds() Gets the value of cacheNodeIds or its default value.
getCheckpointInterval() Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy() Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getImpurity() Gets the value of impurity or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getLeafCol() Gets the value of leafCol or its default value.
getLossType() Gets the value of lossType or its default value.
getMaxBins() Gets the value of maxBins or its default value.
getMaxDepth() Gets the value of maxDepth or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getMaxMemoryInMB() Gets the value of maxMemoryInMB or its default value.
getMinInfoGain() Gets the value of minInfoGain or its default value.
getMinInstancesPerNode() Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode() Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getProbabilityCol() Gets the value of probabilityCol or its default value.
getRawPredictionCol() Gets the value of rawPredictionCol or its default value.
getSeed() Gets the value of seed or its default value.
getStepSize() Gets the value of stepSize or its default value.
getSubsamplingRate() Gets the value of subsamplingRate or its default value.
getThresholds() Gets the value of thresholds or its default value.
getValidationIndicatorCol() Gets the value of validationIndicatorCol or its default value.
getValidationTol() Gets the value of validationTol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setCacheNodeIds(value) Sets the value of cacheNodeIds.
setCheckpointInterval(value) Sets the value of checkpointInterval.
setFeatureSubsetStrategy(value) Sets the value of featureSubsetStrategy.
setFeaturesCol(value) Sets the value of featuresCol.
setImpurity(value) Sets the value of impurity.
setLabelCol(value) Sets the value of labelCol.
setLeafCol(value) Sets the value of leafCol.
setLossType(value) Sets the value of lossType.
setMaxBins(value) Sets the value of maxBins.
setMaxDepth(value) Sets the value of maxDepth.
setMaxIter(value) Sets the value of maxIter.
setMaxMemoryInMB(value) Sets the value of maxMemoryInMB.
setMinInfoGain(value) Sets the value of minInfoGain.
setMinInstancesPerNode(value) Sets the value of minInstancesPerNode.
setMinWeightFractionPerNode(value) Sets the value of minWeightFractionPerNode.
setParams(self, \*[, featuresCol, labelCol, ...]) Sets params for Gradient Boosted Tree Classification.
setPredictionCol(value) Sets the value of predictionCol.
setProbabilityCol(value) Sets the value of probabilityCol.
setRawPredictionCol(value) Sets the value of rawPredictionCol.
setSeed(value) Sets the value of seed.
setStepSize(value) Sets the value of stepSize.
setSubsamplingRate(value) Sets the value of subsamplingRate.
setThresholds(value) Sets the value of thresholds.
setValidationIndicatorCol(value) Sets the value of validationIndicatorCol.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.
Attributes
cacheNodeIds
checkpointInterval
featureSubsetStrategy
featuresCol
impurity
labelCol
leafCol
lossType
maxBins
maxDepth
maxIter
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
params Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
stepSize
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
supportedLossTypes
thresholds
validationIndicatorCol
validationTol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
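A sketch of consuming the iterator, assuming gbt is a GBTClassifier and train_df a training DataFrame; the returned indices are used to realign models because completion order is not guaranteed:

>>> maps = [{gbt.maxDepth: 2}, {gbt.maxDepth: 4}, {gbt.maxDepth: 6}]
>>> models = [None] * len(maps)
>>> for index, m in gbt.fitMultiple(train_df, maps):
...     models[index] = m    # models[i] was fit with maps[i]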
getCacheNodeIds()
Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getFeatureSubsetStrategy()
Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getImpurity()
Gets the value of impurity or its default value.

New in version 1.4.0.

getLabelCol()
Gets the value of labelCol or its default value.

getLeafCol()
Gets the value of leafCol or its default value.

getLossType()
Gets the value of lossType or its default value.

New in version 1.4.0.

getMaxBins()
Gets the value of maxBins or its default value.

getMaxDepth()
Gets the value of maxDepth or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getMaxMemoryInMB()
Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getRawPredictionCol()
Gets the value of rawPredictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getStepSize()
Gets the value of stepSize or its default value.

getSubsamplingRate()
Gets the value of subsamplingRate or its default value.

New in version 1.4.0.

getThresholds()
Gets the value of thresholds or its default value.

getValidationIndicatorCol()
Gets the value of validationIndicatorCol or its default value.

getValidationTol()
Gets the value of validationTol or its default value.

New in version 3.0.0.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.
setCacheNodeIds(value)
Sets the value of cacheNodeIds.

setCheckpointInterval(value)
Sets the value of checkpointInterval.

New in version 1.4.0.

setFeatureSubsetStrategy(value)
Sets the value of featureSubsetStrategy.

New in version 2.4.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setImpurity(value)
Sets the value of impurity.

New in version 1.4.0.

setLabelCol(value)
Sets the value of labelCol.

New in version 3.0.0.

setLeafCol(value)
Sets the value of leafCol.

setLossType(value)
Sets the value of lossType.

New in version 1.4.0.

setMaxBins(value)
Sets the value of maxBins.

setMaxDepth(value)
Sets the value of maxDepth.

setMaxIter(value)
Sets the value of maxIter.

New in version 1.4.0.

setMaxMemoryInMB(value)
Sets the value of maxMemoryInMB.

setMinInfoGain(value)
Sets the value of minInfoGain.

setMinInstancesPerNode(value)
Sets the value of minInstancesPerNode.

setMinWeightFractionPerNode(value)
Sets the value of minWeightFractionPerNode.

New in version 3.0.0.
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance", featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, weightCol=None)
Sets params for Gradient Boosted Tree Classification.
New in version 1.4.0.
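Since setParams mutates the estimator in place and returns it, it can be used to reconfigure an existing instance; a minimal sketch:

>>> from pyspark.ml.classification import GBTClassifier
>>> gbt = GBTClassifier()
>>> same = gbt.setParams(maxIter=50, stepSize=0.05)
>>> same is gbt        # the same instance is returned, not a copy
True
>>> gbt.getMaxIter()
50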
setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)
Sets the value of rawPredictionCol.

New in version 3.0.0.

setSeed(value)
Sets the value of seed.

New in version 1.4.0.

setStepSize(value)
Sets the value of stepSize.

New in version 1.4.0.

setSubsamplingRate(value)
Sets the value of subsamplingRate.

New in version 1.4.0.

setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.

setValidationIndicatorCol(value)
Sets the value of validationIndicatorCol.

New in version 3.0.0.

setWeightCol(value)
Sets the value of weightCol.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')
lossType = Param(parent='undefined', name='lossType', doc='Loss function which GBT tries to minimize (case-insensitive). Supported options: logistic')
maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')
maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')
minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')
minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')
minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
stepSize = Param(parent='undefined', name='stepSize', doc='Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.')
subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')
supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']
supportedImpurities = ['variance']
supportedLossTypes = ['logistic']
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
validationIndicatorCol = Param(parent='undefined', name='validationIndicatorCol', doc='name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.')
validationTol = Param(parent='undefined', name='validationTol', doc='Threshold for stopping early when fit with validation is used. If the error rate on the validation input changes by less than the validationTol, then learning will stop early (before `maxIter`). This parameter is ignored when fit without validation is used.')
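Together, validationIndicatorCol and validationTol enable early stopping; a hedged sketch assuming a training DataFrame train_df (the 80/20 split and the column name are arbitrary choices):

>>> from pyspark.sql import functions as F
>>> from pyspark.ml.classification import GBTClassifier
>>> data = train_df.withColumn("isVal", F.rand(seed=7) > 0.8)  # boolean flag column
>>> gbt = GBTClassifier(maxIter=200, validationIndicatorCol="isVal",
...                     validationTol=0.01)
>>> model = gbt.fit(data)   # may stop well before maxIter iterations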
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
GBTClassificationModel
class pyspark.ml.classification.GBTClassificationModel(java_model=None)
Model fitted by GBTClassifier.
New in version 1.4.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
evaluateEachIteration(dataset) Method to compute error or loss for every iteration of gradient boosting.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCacheNodeIds() Gets the value of cacheNodeIds or its default value.
getCheckpointInterval() Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy() Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getImpurity() Gets the value of impurity or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getLeafCol() Gets the value of leafCol or its default value.
getLossType() Gets the value of lossType or its default value.
getMaxBins() Gets the value of maxBins or its default value.
getMaxDepth() Gets the value of maxDepth or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getMaxMemoryInMB() Gets the value of maxMemoryInMB or its default value.
getMinInfoGain() Gets the value of minInfoGain or its default value.
getMinInstancesPerNode() Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode() Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getProbabilityCol() Gets the value of probabilityCol or its default value.
getRawPredictionCol() Gets the value of rawPredictionCol or its default value.
getSeed() Gets the value of seed or its default value.
getStepSize() Gets the value of stepSize or its default value.
getSubsamplingRate() Gets the value of subsamplingRate or its default value.
getThresholds() Gets the value of thresholds or its default value.
getValidationIndicatorCol() Gets the value of validationIndicatorCol or its default value.
getValidationTol() Gets the value of validationTol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) Predict label for the given features.
predictLeaf(value) Predict the indices of the leaves corresponding to the feature vector.
predictProbability(value) Predict the probability of each class given the features.
predictRaw(value) Raw prediction for each possible label.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setLeafCol(value) Sets the value of leafCol.
setPredictionCol(value) Sets the value of predictionCol.
setProbabilityCol(value) Sets the value of probabilityCol.
setRawPredictionCol(value) Sets the value of rawPredictionCol.
setThresholds(value) Sets the value of thresholds.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
cacheNodeIds
checkpointInterval
featureImportances Estimate of the importance of each feature.
featureSubsetStrategy
featuresCol
getNumTrees Number of trees in ensemble.
impurity
labelCol
leafCol
lossType
maxBins
maxDepth
maxIter
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numClasses Number of classes (values which the label can take).
numFeatures Returns the number of features the model was trained on.
params Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
stepSize
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
supportedLossTypes
thresholds
toDebugString Full description of model.
totalNumNodes Total number of nodes, summed over all trees in the ensemble.
treeWeights Return the weights for each tree.
trees Trees in this ensemble.
validationIndicatorCol
validationTol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluateEachIteration(dataset)
Method to compute error or loss for every iteration of gradient boosting.
New in version 2.4.0.
Parameters
dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on.
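The per-iteration losses can be used to decide how many trees were actually useful; a sketch assuming model is a fitted GBTClassificationModel and validation_df a held-out DataFrame:

>>> losses = model.evaluateEachIteration(validation_df)  # one loss value per tree
>>> best = 1 + losses.index(min(losses))                 # iteration with the lowest loss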
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getCacheNodeIds()
Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getFeatureSubsetStrategy()
Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getImpurity()
Gets the value of impurity or its default value.

New in version 1.4.0.

getLabelCol()
Gets the value of labelCol or its default value.

getLeafCol()
Gets the value of leafCol or its default value.

getLossType()
Gets the value of lossType or its default value.

New in version 1.4.0.

getMaxBins()
Gets the value of maxBins or its default value.

getMaxDepth()
Gets the value of maxDepth or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getMaxMemoryInMB()
Gets the value of maxMemoryInMB or its default value.
getMinInfoGain()
Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getRawPredictionCol()
Gets the value of rawPredictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getStepSize()
Gets the value of stepSize or its default value.

getSubsamplingRate()
Gets the value of subsamplingRate or its default value.

New in version 1.4.0.

getThresholds()
Gets the value of thresholds or its default value.

getValidationIndicatorCol()
Gets the value of validationIndicatorCol or its default value.

getValidationTol()
Gets the value of validationTol or its default value.

New in version 3.0.0.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)
Predict label for the given features.

New in version 3.0.0.

predictLeaf(value)
Predict the indices of the leaves corresponding to the feature vector.

New in version 3.0.0.

predictProbability(value)
Predict the probability of each class given the features.

New in version 3.0.0.

predictRaw(value)
Raw prediction for each possible label.

New in version 3.0.0.

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setLeafCol(value)
Sets the value of leafCol.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)
Sets the value of rawPredictionCol.

New in version 3.0.0.

setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
featureImportances
Estimate of the importance of each feature.

Each feature's importance is the average of its importance across all trees in the ensemble. The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) and follows the implementation from scikit-learn.
New in version 2.0.0.
See also:
DecisionTreeClassificationModel.featureImportances
featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
getNumTrees
Number of trees in ensemble.
New in version 2.0.0.
impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')
lossType = Param(parent='undefined', name='lossType', doc='Loss function which GBT tries to minimize (case-insensitive). Supported options: logistic')
maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')
maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')
minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')
minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')
minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')
numClasses
Number of classes (values which the label can take).
New in version 2.1.0.
numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
stepSize = Param(parent='undefined', name='stepSize', doc='Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.')
subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')
supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']
supportedImpurities = ['variance']
supportedLossTypes = ['logistic']
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
toDebugString
Full description of model.
New in version 2.0.0.
totalNumNodes
Total number of nodes, summed over all trees in the ensemble.
New in version 2.0.0.
treeWeights
Return the weights for each tree.
New in version 1.5.0.
trees
Trees in this ensemble. Warning: These have null parent Estimators.
New in version 2.0.0.
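A sketch of inspecting the ensemble members alongside their weights, assuming model is a fitted GBTClassificationModel:

>>> for w, tree in zip(model.treeWeights, model.trees):
...     print(w, tree.depth, tree.numNodes)   # weight, depth and size of each tree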
validationIndicatorCol = Param(parent='undefined', name='validationIndicatorCol', doc='name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.')
validationTol = Param(parent='undefined', name='validationTol', doc='Threshold for stopping early when fit with validation is used. If the error rate on the validation input changes by less than the validationTol, then learning will stop early (before `maxIter`). This parameter is ignored when fit without validation is used.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
RandomForestClassifier
class pyspark.ml.classification.RandomForestClassifier(*args, **kwargs)
Random Forest learning algorithm for classification. It supports both binary and multiclass labels, as well as both continuous and categorical features.
New in version 1.4.0.
Examples
>>> import numpy
>>> from numpy import allclose
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42,
...     leafCol="leafId")
>>> rf.getMinWeightFractionPerNode()
0.0
>>> model = rf.fit(td)
>>> model.getLabelCol()
'indexed'
>>> model.setFeaturesCol("features")
RandomForestClassificationModel...
>>> model.setRawPredictionCol("newRawPrediction")
RandomForestClassificationModel...
>>> model.getBootstrap()
True
>>> model.getRawPredictionCol()
'newRawPrediction'
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictRaw(test0.head().features)
DenseVector([2.0, 0.0])
>>> model.predictProbability(test0.head().features)
DenseVector([1.0, 0.0])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> numpy.argmax(result.probability)
0
>>> numpy.argmax(result.newRawPrediction)
0
>>> result.leafId
DenseVector([0.0, 0.0, 0.0])
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> model.trees
[DecisionTreeClassificationModel...depth=..., DecisionTreeClassificationModel...]
>>> rfc_path = temp_path + "/rfc"
>>> rf.save(rfc_path)
>>> rf2 = RandomForestClassifier.load(rfc_path)
>>> rf2.getNumTrees()
3
>>> model_path = temp_path + "/rfc_model"
>>> model.save(model_path)
>>> model2 = RandomForestClassificationModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getBootstrap()  Gets the value of bootstrap or its default value.
getCacheNodeIds()  Gets the value of cacheNodeIds or its default value.
getCheckpointInterval()  Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy()  Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getImpurity()  Gets the value of impurity or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getLeafCol()  Gets the value of leafCol or its default value.
getMaxBins()  Gets the value of maxBins or its default value.
getMaxDepth()  Gets the value of maxDepth or its default value.
getMaxMemoryInMB()  Gets the value of maxMemoryInMB or its default value.
getMinInfoGain()  Gets the value of minInfoGain or its default value.
getMinInstancesPerNode()  Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode()  Gets the value of minWeightFractionPerNode or its default value.
getNumTrees()  Gets the value of numTrees or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getSeed()  Gets the value of seed or its default value.
getSubsamplingRate()  Gets the value of subsamplingRate or its default value.
getThresholds()  Gets the value of thresholds or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setBootstrap(value)  Sets the value of bootstrap.
setCacheNodeIds(value)  Sets the value of cacheNodeIds.
setCheckpointInterval(value)  Sets the value of checkpointInterval.
setFeatureSubsetStrategy(value)  Sets the value of featureSubsetStrategy.
setFeaturesCol(value)  Sets the value of featuresCol.
setImpurity(value)  Sets the value of impurity.
setLabelCol(value)  Sets the value of labelCol.
setLeafCol(value)  Sets the value of leafCol.
setMaxBins(value)  Sets the value of maxBins.
setMaxDepth(value)  Sets the value of maxDepth.
setMaxMemoryInMB(value)  Sets the value of maxMemoryInMB.
setMinInfoGain(value)  Sets the value of minInfoGain.
setMinInstancesPerNode(value)  Sets the value of minInstancesPerNode.
setMinWeightFractionPerNode(value)  Sets the value of minWeightFractionPerNode.
setNumTrees(value)  Sets the value of numTrees.
setParams(self[, featuresCol, labelCol, ...])  Sets params for this RandomForestClassifier.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setSeed(value)  Sets the value of seed.
setSubsamplingRate(value)  Sets the value of subsamplingRate.
setThresholds(value)  Sets the value of thresholds.
setWeightCol(value)  Sets the value of weightCol.
write()  Returns an MLWriter instance for this ML instance.
Attributes
bootstrap
cacheNodeIds
checkpointInterval
featureSubsetStrategy
featuresCol
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numTrees
params  Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
thresholds
weightCol
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map
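    The merge ordering can be observed directly. A minimal illustrative sketch (not part of the generated reference) using the estimator's own params:

    # Sketch: extractParamMap precedence -- defaults < user-supplied < extra.
    rf = RandomForestClassifier()                    # numTrees defaults to 20
    print(rf.getOrDefault(rf.numTrees))              # 20 (default value)
    rf.setNumTrees(50)                               # user-supplied value wins over the default
    print(rf.extractParamMap()[rf.numTrees])         # 50
    print(rf.extractParamMap({rf.numTrees: 100})[rf.numTrees])  # 100 (extra wins over both)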
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

    Returns
        Transformer or a list of Transformer: fitted model(s)
fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.

    New in version 2.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        paramMaps [collections.abc.Sequence] A Sequence of param maps.

    Returns
        _FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
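    For illustration (a sketch, assuming the indexed training DataFrame td from the example above), fitMultiple can train several forests in one call:

    # Sketch: one model per param map, reassembled in param-map order.
    rf = RandomForestClassifier(labelCol="indexed", seed=42)
    paramMaps = [{rf.numTrees: n} for n in (5, 10, 20)]
    models = [None] * len(paramMaps)
    for index, model in rf.fitMultiple(td, paramMaps):
        models[index] = model   # index values may not arrive sequentially
    # Equivalent shortcut: rf.fit(td, paramMaps) returns the same models as a list.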
getBootstrap()
    Gets the value of bootstrap or its default value.

    New in version 3.0.0.

getCacheNodeIds()
    Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
    Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy()
    Gets the value of featureSubsetStrategy or its default value.

    New in version 1.4.0.

getFeaturesCol()
    Gets the value of featuresCol or its default value.

getImpurity()
    Gets the value of impurity or its default value.

    New in version 1.6.0.

getLabelCol()
    Gets the value of labelCol or its default value.

getLeafCol()
    Gets the value of leafCol or its default value.

getMaxBins()
    Gets the value of maxBins or its default value.

getMaxDepth()
    Gets the value of maxDepth or its default value.

getMaxMemoryInMB()
    Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
    Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
    Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
    Gets the value of minWeightFractionPerNode or its default value.

getNumTrees()
    Gets the value of numTrees or its default value.

    New in version 1.4.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getPredictionCol()
    Gets the value of predictionCol or its default value.

getProbabilityCol()
    Gets the value of probabilityCol or its default value.

getRawPredictionCol()
    Gets the value of rawPredictionCol or its default value.

getSeed()
    Gets the value of seed or its default value.

getSubsamplingRate()
    Gets the value of subsamplingRate or its default value.

    New in version 1.4.0.
getThresholds()
    Gets the value of thresholds or its default value.

getWeightCol()
    Gets the value of weightCol or its default value.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setBootstrap(value)
    Sets the value of bootstrap.

    New in version 3.0.0.

setCacheNodeIds(value)
    Sets the value of cacheNodeIds.

setCheckpointInterval(value)
    Sets the value of checkpointInterval.

setFeatureSubsetStrategy(value)
    Sets the value of featureSubsetStrategy.

    New in version 2.4.0.

setFeaturesCol(value)
    Sets the value of featuresCol.

    New in version 3.0.0.

setImpurity(value)
    Sets the value of impurity.

    New in version 1.4.0.

setLabelCol(value)
    Sets the value of labelCol.

    New in version 3.0.0.

setLeafCol(value)
    Sets the value of leafCol.
setMaxBins(value)
    Sets the value of maxBins.

setMaxDepth(value)
    Sets the value of maxDepth.

setMaxMemoryInMB(value)
    Sets the value of maxMemoryInMB.

setMinInfoGain(value)
    Sets the value of minInfoGain.

setMinInstancesPerNode(value)
    Sets the value of minInstancesPerNode.

setMinWeightFractionPerNode(value)
    Sets the value of minWeightFractionPerNode.

    New in version 3.0.0.

setNumTrees(value)
    Sets the value of numTrees.

    New in version 1.4.0.

setParams(self, featuresCol='features', labelCol='label', predictionCol='prediction', probabilityCol='probability', rawPredictionCol='rawPrediction', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, impurity='gini', numTrees=20, featureSubsetStrategy='auto', subsamplingRate=1.0, leafCol='', minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)
    Sets params for this RandomForestClassifier.

    New in version 1.4.0.

setPredictionCol(value)
    Sets the value of predictionCol.

    New in version 3.0.0.

setProbabilityCol(value)
    Sets the value of probabilityCol.

    New in version 3.0.0.

setRawPredictionCol(value)
    Sets the value of rawPredictionCol.

    New in version 3.0.0.

setSeed(value)
    Sets the value of seed.

setSubsamplingRate(value)
    Sets the value of subsamplingRate.

    New in version 1.4.0.

setThresholds(value)
    Sets the value of thresholds.

    New in version 3.0.0.

setWeightCol(value)
    Sets the value of weightCol.
New in version 3.0.0.
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
bootstrap = Param(parent='undefined', name='bootstrap', doc='Whether bootstrap samples are used when building trees.')
cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: entropy, gini')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')
maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')
maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')
maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')
minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')
minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')
minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')
numTrees = Param(parent='undefined', name='numTrees', doc='Number of trees to train (>= 1).')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')
supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']
supportedImpurities = ['entropy', 'gini']
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
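As a worked illustration of the thresholds rule documented above (plain Python, not part of the generated reference), the predicted class is the one maximizing p/t, which can differ from the class with the largest raw probability:

    # Illustrative only: the p/t selection rule applied by `thresholds`.
    probability = [0.30, 0.70]    # p: class probabilities from the model
    thresholds = [0.01, 10.0]     # t: per-class thresholds
    scaled = [p / t for p, t in zip(probability, thresholds)]
    prediction = scaled.index(max(scaled))
    print(scaled)       # [30.0, 0.07]
    print(prediction)   # 0 -- class 0 wins despite its lower raw probability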
RandomForestClassificationModel
class pyspark.ml.classification.RandomForestClassificationModel(java_model=None)
    Model fitted by RandomForestClassifier.
New in version 1.4.0.
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset)  Evaluates the model on a test dataset.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBootstrap()  Gets the value of bootstrap or its default value.
getCacheNodeIds()  Gets the value of cacheNodeIds or its default value.
getCheckpointInterval()  Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy()  Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getImpurity()  Gets the value of impurity or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getLeafCol()  Gets the value of leafCol or its default value.
getMaxBins()  Gets the value of maxBins or its default value.
getMaxDepth()  Gets the value of maxDepth or its default value.
getMaxMemoryInMB()  Gets the value of maxMemoryInMB or its default value.
getMinInfoGain()  Gets the value of minInfoGain or its default value.
getMinInstancesPerNode()  Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode()  Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getSeed()  Gets the value of seed or its default value.
getSubsamplingRate()  Gets the value of subsamplingRate or its default value.
getThresholds()  Gets the value of thresholds or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)  Predict label for the given features.
predictLeaf(value)  Predict the indices of the leaves corresponding to the feature vector.
predictProbability(value)  Predict the probability of each class given the features.
predictRaw(value)  Raw prediction for each possible label.
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setLeafCol(value)  Sets the value of leafCol.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setThresholds(value)  Sets the value of thresholds.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
bootstrap
cacheNodeIds
checkpointInterval
featureImportances  Estimate of the importance of each feature.
featureSubsetStrategy
featuresCol
getNumTrees  Number of trees in ensemble.
hasSummary  Indicates whether a training summary exists for this model instance.
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numClasses  Number of classes (values which the label can take).
numFeatures  Returns the number of features the model was trained on.
numTrees
params  Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
subsamplingRate
summary  Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on the training set.
supportedFeatureSubsetStrategies
supportedImpurities
thresholds
toDebugString  Full description of model.
totalNumNodes  Total number of nodes, summed over all trees in the ensemble.
treeWeights  Return the weights for each tree.
trees  Trees in this ensemble.
weightCol
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance
evaluate(dataset)
    Evaluates the model on a test dataset.

    New in version 3.1.0.

    Parameters
        dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on.
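    For example (a sketch, assuming a fitted model and a held-out DataFrame test_df with the same schema as the training data), the returned summary exposes the metrics described under the summary classes below:

    # Sketch: evaluate on held-out data and read a few summary metrics.
    test_summary = model.evaluate(test_df)
    print(test_summary.accuracy)
    print(test_summary.weightedPrecision)
    print(test_summary.fMeasureByLabel())   # one F-measure per label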
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map

getBootstrap()
    Gets the value of bootstrap or its default value.

    New in version 3.0.0.

getCacheNodeIds()
    Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
    Gets the value of checkpointInterval or its default value.

getFeatureSubsetStrategy()
    Gets the value of featureSubsetStrategy or its default value.

    New in version 1.4.0.

getFeaturesCol()
    Gets the value of featuresCol or its default value.

getImpurity()
    Gets the value of impurity or its default value.

    New in version 1.6.0.

getLabelCol()
    Gets the value of labelCol or its default value.

getLeafCol()
    Gets the value of leafCol or its default value.

getMaxBins()
    Gets the value of maxBins or its default value.

getMaxDepth()
    Gets the value of maxDepth or its default value.

getMaxMemoryInMB()
    Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
    Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
    Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
    Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getPredictionCol()
    Gets the value of predictionCol or its default value.

getProbabilityCol()
    Gets the value of probabilityCol or its default value.

getRawPredictionCol()
    Gets the value of rawPredictionCol or its default value.

getSeed()
    Gets the value of seed or its default value.

getSubsamplingRate()
    Gets the value of subsamplingRate or its default value.

    New in version 1.4.0.

getThresholds()
    Gets the value of thresholds or its default value.

getWeightCol()
    Gets the value of weightCol or its default value.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
    Predict label for the given features.

    New in version 3.0.0.

predictLeaf(value)
    Predict the indices of the leaves corresponding to the feature vector.

    New in version 3.0.0.

predictProbability(value)
    Predict the probability of each class given the features.

    New in version 3.0.0.

predictRaw(value)
    Raw prediction for each possible label.

    New in version 3.0.0.
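    These single-vector methods score one feature vector locally, without a distributed transform. A brief sketch (assuming the fitted model from the RandomForestClassifier example; outputs shown are from that toy data):

    from pyspark.ml.linalg import Vectors

    # Local scoring of a single feature vector.
    v = Vectors.dense(-1.0)
    model.predict(v)             # predicted label, e.g. 0.0
    model.predictRaw(v)          # per-class vote counts, e.g. DenseVector([2.0, 0.0])
    model.predictProbability(v)  # normalized probabilities, e.g. DenseVector([1.0, 0.0])
    model.predictLeaf(v)         # leaf index reached in each tree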
classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setFeaturesCol(value)
    Sets the value of featuresCol.

    New in version 3.0.0.

setLeafCol(value)
    Sets the value of leafCol.

setPredictionCol(value)
    Sets the value of predictionCol.

    New in version 3.0.0.

setProbabilityCol(value)
    Sets the value of probabilityCol.

    New in version 3.0.0.

setRawPredictionCol(value)
    Sets the value of rawPredictionCol.

    New in version 3.0.0.

setThresholds(value)
    Sets the value of thresholds.

    New in version 3.0.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset
        params [dict, optional] an optional param map that overrides embedded params.

    Returns
        pyspark.sql.DataFrame transformed dataset

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
bootstrap = Param(parent='undefined', name='bootstrap', doc='Whether bootstrap samples are used when building trees.')
cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
featureImportances
    Estimate of the importance of each feature.

    Each feature's importance is the average of its importance across all trees in the ensemble. The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) and follows the implementation from scikit-learn.
New in version 2.0.0.
See also:
DecisionTreeClassificationModel.featureImportances
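    A short sketch (with hypothetical feature names; not from the reference itself) pairing the importance vector with the columns used to assemble the features vector:

    # Sketch: rank features by importance; `feature_names` is your own list of
    # input columns, in the same order used to build the features vector.
    importances = model.featureImportances          # vector, sums to 1.0
    feature_names = ["age", "income", "clicks"]     # hypothetical names
    ranked = sorted(zip(feature_names, importances.toArray()),
                    key=lambda pair: pair[1], reverse=True)
    for name, score in ranked:
        print(f"{name}: {score:.3f}")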
featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
getNumTrees
    Number of trees in ensemble.

    New in version 2.0.0.

hasSummary
    Indicates whether a training summary exists for this model instance.
New in version 2.1.0.
impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: entropy, gini')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')
maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')
maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')
maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')
minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')
minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')
minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')
numClasses
    Number of classes (values which the label can take).

    New in version 2.1.0.

numFeatures
    Returns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
numTrees = Param(parent='undefined', name='numTrees', doc='Number of trees to train (>= 1).')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')
summary
    Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on the training set. An exception is thrown if trainingSummary is None.
New in version 3.1.0.
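    For instance (an illustrative sketch), guard access with hasSummary, since only a model trained in the current session carries a summary:

    # Sketch: training summaries exist only on freshly trained models,
    # not on models loaded from disk.
    if model.hasSummary:
        s = model.summary
        print(s.accuracy, s.totalIterations)
    else:
        # A loaded model has no summary; recompute metrics with evaluate().
        s = model.evaluate(train_df)   # train_df: hypothetical DataFrame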
supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']
supportedImpurities = ['entropy', 'gini']
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
toDebugString
    Full description of model.

    New in version 2.0.0.

totalNumNodes
    Total number of nodes, summed over all trees in the ensemble.

    New in version 2.0.0.

treeWeights
    Return the weights for each tree.

    New in version 1.5.0.

trees
    Trees in this ensemble. Warning: These have null parent Estimators.
New in version 2.0.0.
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
RandomForestClassificationSummary
class pyspark.ml.classification.RandomForestClassificationSummary(java_obj=None)
    Abstraction for RandomForestClassification results for a given model.
New in version 3.1.0.
Methods
fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.
Attributes
accuracy  Returns accuracy.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
labelCol  Field in "predictions" which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
precisionByLabel  Returns precision for each label (category).
predictionCol  Field in "predictions" which gives the prediction of each class.
predictions  Dataframe outputted by the model's transform method.
recallByLabel  Returns recall for each label (category).
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)
    Returns f-measure for each label (category).

    New in version 3.1.0.

weightedFMeasure(beta=1.0)
    Returns weighted averaged f-measure.
New in version 3.1.0.
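A brief illustrative sketch (assuming summary is a RandomForestClassificationSummary, e.g. obtained from model.summary or model.evaluate()):

    # Sketch: per-label F-measure printed next to the label values themselves.
    for label, f1 in zip(summary.labels, summary.fMeasureByLabel()):
        print(f"label {label}: F1 = {f1:.3f}")
    print("weighted F2:", summary.weightedFMeasure(beta=2.0))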
Attributes Documentation
accuracy
    Returns accuracy: the number of correctly classified instances divided by the total number of instances.

    New in version 3.1.0.

falsePositiveRateByLabel
    Returns false positive rate for each label (category).

    New in version 3.1.0.

labelCol
    Field in "predictions" which gives the true label of each instance.
New in version 3.1.0.
labels
    Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.
New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.
precisionByLabel
    Returns precision for each label (category).

    New in version 3.1.0.

predictionCol
    Field in "predictions" which gives the prediction of each class.

    New in version 3.1.0.

predictions
    Dataframe outputted by the model's transform method.

    New in version 3.1.0.

recallByLabel
    Returns recall for each label (category).

    New in version 3.1.0.

truePositiveRateByLabel
    Returns true positive rate for each label (category).

    New in version 3.1.0.

weightCol
    Field in "predictions" which gives the weight of each instance as a vector.
New in version 3.1.0.
weightedFalsePositiveRate
    Returns weighted false positive rate.

    New in version 3.1.0.

weightedPrecision
    Returns weighted averaged precision.

    New in version 3.1.0.

weightedRecall
    Returns weighted averaged recall. (equals to precision, recall and f-measure)

    New in version 3.1.0.

weightedTruePositiveRate
    Returns weighted true positive rate. (equals to precision, recall and f-measure)
New in version 3.1.0.
RandomForestClassificationTrainingSummary
class pyspark.ml.classification.RandomForestClassificationTrainingSummary(java_obj=None)
    Abstraction for RandomForestClassification training results for a given model.
New in version 3.1.0.
Methods
fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.
Attributes
accuracy  Returns accuracy.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
labelCol  Field in "predictions" which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
objectiveHistory  Objective function (scaled loss + regularization) at each iteration.
precisionByLabel  Returns precision for each label (category).
predictionCol  Field in "predictions" which gives the prediction of each class.
predictions  Dataframe outputted by the model's transform method.
recallByLabel  Returns recall for each label (category).
totalIterations  Number of training iterations until termination.
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)
    Returns f-measure for each label (category).

    New in version 3.1.0.

weightedFMeasure(beta=1.0)
    Returns weighted averaged f-measure.
New in version 3.1.0.
Attributes Documentation
accuracy
    Returns accuracy: the number of correctly classified instances divided by the total number of instances.

    New in version 3.1.0.

falsePositiveRateByLabel
    Returns false positive rate for each label (category).

    New in version 3.1.0.

labelCol
    Field in "predictions" which gives the true label of each instance.

    New in version 3.1.0.

labels
    Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.
New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.
objectiveHistory
    Objective function (scaled loss + regularization) at each iteration. It contains one more element, the initial state, than the number of iterations.

    New in version 3.1.0.

precisionByLabel
    Returns precision for each label (category).

    New in version 3.1.0.

predictionCol
    Field in "predictions" which gives the prediction of each class.

    New in version 3.1.0.

predictions
    Dataframe outputted by the model's transform method.

    New in version 3.1.0.

recallByLabel
    Returns recall for each label (category).

    New in version 3.1.0.

totalIterations
    Number of training iterations until termination.
New in version 3.1.0.
truePositiveRateByLabel
    Returns true positive rate for each label (category).

    New in version 3.1.0.

weightCol
    Field in "predictions" which gives the weight of each instance as a vector.

    New in version 3.1.0.

weightedFalsePositiveRate
    Returns weighted false positive rate.

    New in version 3.1.0.

weightedPrecision
    Returns weighted averaged precision.

    New in version 3.1.0.

weightedRecall
    Returns weighted averaged recall. (equals to precision, recall and f-measure)

    New in version 3.1.0.

weightedTruePositiveRate
    Returns weighted true positive rate. (equals to precision, recall and f-measure)
New in version 3.1.0.
BinaryRandomForestClassificationSummary
class pyspark.ml.classification.BinaryRandomForestClassificationSummary(java_obj=None)
    BinaryRandomForestClassification results for a given model.
New in version 3.1.0.
Methods
fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.
Attributes
accuracy  Returns accuracy.
areaUnderROC  Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold  Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
labelCol  Field in "predictions" which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
pr  Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel  Returns precision for each label (category).
precisionByThreshold  Returns a dataframe with two fields (threshold, precision) curve.
predictionCol  Field in "predictions" which gives the prediction of each class.
predictions  Dataframe outputted by the model's transform method.
recallByLabel  Returns recall for each label (category).
recallByThreshold  Returns a dataframe with two fields (threshold, recall) curve.
roc  Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol  Field in "predictions" which gives the probability or raw prediction of each class as a vector.
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)
    Returns f-measure for each label (category).

    New in version 3.1.0.

weightedFMeasure(beta=1.0)
    Returns weighted averaged f-measure.
New in version 3.1.0.
Attributes Documentation
accuracy
    Returns accuracy: the number of correctly classified instances divided by the total number of instances.

    New in version 3.1.0.

areaUnderROC
    Computes the area under the receiver operating characteristic (ROC) curve.

    New in version 3.1.0.

fMeasureByThreshold
    Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
New in version 3.1.0.
falsePositiveRateByLabel
    Returns false positive rate for each label (category).

    New in version 3.1.0.

labelCol
    Field in "predictions" which gives the true label of each instance.

    New in version 3.1.0.

labels
    Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.
New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.
pr
    Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.

    New in version 3.1.0.

precisionByLabel
    Returns precision for each label (category).

    New in version 3.1.0.

precisionByThreshold
    Returns a dataframe with two fields (threshold, precision) curve. Every possible probability obtained in transforming the dataset is used as a threshold in calculating the precision.

    New in version 3.1.0.

predictionCol
    Field in "predictions" which gives the prediction of each class.

    New in version 3.1.0.

predictions
    Dataframe outputted by the model's transform method.

    New in version 3.1.0.

recallByLabel
    Returns recall for each label (category).

    New in version 3.1.0.

recallByThreshold
    Returns a dataframe with two fields (threshold, recall) curve. Every possible probability obtained in transforming the dataset is used as a threshold in calculating the recall.

    New in version 3.1.0.
roc
    Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
New in version 3.1.0.
Notes
Wikipedia reference
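As an illustration (a sketch, assuming summary is a BinaryRandomForestClassificationSummary), the curve is an ordinary DataFrame and can be collected alongside the scalar AUC:

    # Sketch: inspect the ROC curve and the corresponding area under it.
    roc_df = summary.roc                  # DataFrame with columns FPR, TPR
    roc_df.show(5)
    points = [(r["FPR"], r["TPR"]) for r in roc_df.collect()]
    print("AUC:", summary.areaUnderROC)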
scoreCol
    Field in "predictions" which gives the probability or raw prediction of each class as a vector.

    New in version 3.1.0.

truePositiveRateByLabel
    Returns true positive rate for each label (category).

    New in version 3.1.0.

weightCol
    Field in "predictions" which gives the weight of each instance as a vector.

    New in version 3.1.0.

weightedFalsePositiveRate
    Returns weighted false positive rate.

    New in version 3.1.0.

weightedPrecision
    Returns weighted averaged precision.

    New in version 3.1.0.

weightedRecall
    Returns weighted averaged recall. (equals to precision, recall and f-measure)

    New in version 3.1.0.

weightedTruePositiveRate
    Returns weighted true positive rate. (equals to precision, recall and f-measure)
New in version 3.1.0.
BinaryRandomForestClassificationTrainingSummary
class pyspark.ml.classification.BinaryRandomForestClassificationTrainingSummary(java_obj=None)
    BinaryRandomForestClassification training results for a given model.
New in version 3.1.0.
Methods
fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.
Attributes
accuracy  Returns accuracy.
areaUnderROC  Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold  Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
labelCol  Field in "predictions" which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
objectiveHistory  Objective function (scaled loss + regularization) at each iteration.
pr  Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel  Returns precision for each label (category).
precisionByThreshold  Returns a dataframe with two fields (threshold, precision) curve.
predictionCol  Field in "predictions" which gives the prediction of each class.
predictions  Dataframe outputted by the model's transform method.
recallByLabel  Returns recall for each label (category).
recallByThreshold  Returns a dataframe with two fields (threshold, recall) curve.
roc  Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol  Field in "predictions" which gives the probability or raw prediction of each class as a vector.
totalIterations  Number of training iterations until termination.
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)
    Returns f-measure for each label (category).

    New in version 3.1.0.

weightedFMeasure(beta=1.0)
    Returns weighted averaged f-measure.
New in version 3.1.0.
Attributes Documentation
accuracy
    Returns accuracy: the number of correctly classified instances divided by the total number of instances.

    New in version 3.1.0.

areaUnderROC
    Computes the area under the receiver operating characteristic (ROC) curve.

    New in version 3.1.0.

fMeasureByThreshold
    Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
New in version 3.1.0.
falsePositiveRateByLabel
    Returns false positive rate for each label (category).

    New in version 3.1.0.

labelCol
    Field in "predictions" which gives the true label of each instance.

    New in version 3.1.0.

labels
    Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.
New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.

objectiveHistory
    Objective function (scaled loss + regularization) at each iteration. It contains one more element, the initial state, than the number of iterations.

    New in version 3.1.0.

pr
    Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
New in version 3.1.0.
precisionByLabel
    Returns precision for each label (category).

    New in version 3.1.0.

precisionByThreshold
    Returns a dataframe with two fields (threshold, precision) curve. Every possible probability obtained in transforming the dataset is used as a threshold in calculating the precision.

    New in version 3.1.0.

predictionCol
    Field in "predictions" which gives the prediction of each class.

    New in version 3.1.0.

predictions
    Dataframe outputted by the model's transform method.

    New in version 3.1.0.

recallByLabel
    Returns recall for each label (category).

    New in version 3.1.0.

recallByThreshold
    Returns a dataframe with two fields (threshold, recall) curve. Every possible probability obtained in transforming the dataset is used as a threshold in calculating the recall.

    New in version 3.1.0.

roc
    Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
New in version 3.1.0.
Notes
Wikipedia reference
scoreCol
    Field in "predictions" which gives the probability or raw prediction of each class as a vector.

    New in version 3.1.0.

totalIterations
    Number of training iterations until termination.

    New in version 3.1.0.

truePositiveRateByLabel
    Returns true positive rate for each label (category).

    New in version 3.1.0.

weightCol
    Field in "predictions" which gives the weight of each instance as a vector.

    New in version 3.1.0.

weightedFalsePositiveRate
    Returns weighted false positive rate.

    New in version 3.1.0.

weightedPrecision
    Returns weighted averaged precision.

    New in version 3.1.0.

weightedRecall
    Returns weighted averaged recall. (equals to precision, recall and f-measure)

    New in version 3.1.0.

weightedTruePositiveRate
    Returns weighted true positive rate. (equals to precision, recall and f-measure)
New in version 3.1.0.
NaiveBayes
class pyspark.ml.classification.NaiveBayes(*args, **kwargs)
    Naive Bayes Classifiers. It supports both Multinomial and Bernoulli NB. Multinomial NB can handle finitely supported discrete data. For example, by converting documents into TF-IDF vectors, it can be used for document classification. By making every vector binary (0/1) data, it can also be used as Bernoulli NB.
The input feature values for Multinomial NB and Bernoulli NB must be nonnegative. Since 3.0.0, it supports Complement NB, which is an adaptation of Multinomial NB. Specifically, Complement NB uses statistics from the complement of each class to compute the model's coefficients. The inventors of Complement NB show empirically that the parameter estimates for CNB are more stable than those for Multinomial NB. Like Multinomial NB, the input feature values for Complement NB must be nonnegative. Since 3.0.0, it also supports Gaussian NB, which can handle continuous data.
New in version 1.5.0.
Examples
>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
...     Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
...     Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0]))])
>>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")
>>> model = nb.fit(df)
>>> model.setFeaturesCol("features")
NaiveBayesModel...
>>> model.getSmoothing()
1.0
>>> model.pi
DenseVector([-0.81..., -0.58...])
>>> model.theta
DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1)
>>> model.sigma
DenseMatrix(0, 0, [...], ...)
>>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
>>> model.predict(test0.head().features)
1.0
>>> model.predictRaw(test0.head().features)
DenseVector([-1.72..., -0.99...])
>>> model.predictProbability(test0.head().features)
DenseVector([0.32..., 0.67...])
>>> result = model.transform(test0).head()
>>> result.prediction
1.0
>>> result.probability
DenseVector([0.32..., 0.67...])
>>> result.rawPrediction
DenseVector([-1.72..., -0.99...])
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
>>> model.transform(test1).head().prediction
1.0
>>> nb_path = temp_path + "/nb"
>>> nb.save(nb_path)
>>> nb2 = NaiveBayes.load(nb_path)
>>> nb2.getSmoothing()
1.0
>>> model_path = temp_path + "/nb_model"
>>> model.save(model_path)
>>> model2 = NaiveBayesModel.load(model_path)
>>> model.pi == model2.pi
True
>>> model.theta == model2.theta
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> nb = nb.setThresholds([0.01, 10.00])
>>> model3 = nb.fit(df)
>>> result = model3.transform(test0).head()
>>> result.prediction
0.0
>>> nb3 = NaiveBayes().setModelType("gaussian")
>>> model4 = nb3.fit(df)
>>> model4.getModelType()
'gaussian'
>>> model4.sigma
DenseMatrix(2, 2, [0.0, 0.25, 0.0, 0.0], 1)
>>> nb5 = NaiveBayes(smoothing=1.0, modelType="complement", weightCol="weight")
>>> model5 = nb5.fit(df)
>>> model5.getModelType()
'complement'
>>> model5.theta
DenseMatrix(2, 2, [...], 1)
>>> model5.sigma
DenseMatrix(0, 0, [...], ...)
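To tie this to the document-classification use case described above, here is a hedged sketch (hypothetical column names and toy data) of a TF-IDF plus multinomial NB pipeline:

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import Tokenizer, HashingTF, IDF

    # Sketch: documents -> TF-IDF vectors -> multinomial Naive Bayes.
    docs = spark.createDataFrame([
        (0.0, "spark is fast"),
        (1.0, "naive bayes is simple"),
    ], ["label", "text"])                  # hypothetical toy corpus

    tokenizer = Tokenizer(inputCol="text", outputCol="words")
    tf = HashingTF(inputCol="words", outputCol="tf", numFeatures=1024)
    idf = IDF(inputCol="tf", outputCol="features")
    nb = NaiveBayes(modelType="multinomial", smoothing=1.0)

    pipeline_model = Pipeline(stages=[tokenizer, tf, idf, nb]).fit(docs)
    pipeline_model.transform(docs).select("label", "prediction").show()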
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getModelType()  Gets the value of modelType or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getSmoothing()  Gets the value of smoothing or its default value.
getThresholds()  Gets the value of thresholds or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setLabelCol(value)  Sets the value of labelCol.
setModelType(value)  Sets the value of modelType.
setParams(self, \*[, featuresCol, labelCol, ...])  Sets params for Naive Bayes.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setSmoothing(value)  Sets the value of smoothing.
setThresholds(value)  Sets the value of thresholds.
setWeightCol(value)  Sets the value of weightCol.
write()  Returns an MLWriter instance for this ML instance.
Attributes
featuresCol
labelCol
modelType
params: Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
smoothing
thresholds
weightCol
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map
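The merge ordering can be seen directly; a minimal sketch, assuming a NaiveBayes estimator and illustrative smoothing values:

from pyspark.ml.classification import NaiveBayes

nb = NaiveBayes(smoothing=2.0)                     # user-supplied value overrides the default (1.0)
print(nb.extractParamMap()[nb.smoothing])          # 2.0
print(nb.extractParamMap({nb.smoothing: 3.0})[nb.smoothing])  # 3.0: extra overrides user-supplied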
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

    Returns
        Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.

    New in version 2.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        paramMaps [collections.abc.Sequence] A Sequence of param maps.

    Returns
        _FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
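A minimal sketch of driving fitMultiple directly (tools such as CrossValidator normally call it for you); it assumes the `df` built in the example above, and the smoothing values are illustrative:

nb = NaiveBayes()
paramMaps = [{nb.smoothing: s} for s in (0.5, 1.0, 2.0)]
models = [None] * len(paramMaps)
for index, model in nb.fitMultiple(df, paramMaps):
    models[index] = model   # indices may arrive out of order, so place by index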
getFeaturesCol()
    Gets the value of featuresCol or its default value.

getLabelCol()
    Gets the value of labelCol or its default value.

getModelType()
    Gets the value of modelType or its default value.

    New in version 1.5.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getPredictionCol()
    Gets the value of predictionCol or its default value.

getProbabilityCol()
    Gets the value of probabilityCol or its default value.

getRawPredictionCol()
    Gets the value of rawPredictionCol or its default value.

getSmoothing()
    Gets the value of smoothing or its default value.

    New in version 1.5.0.

getThresholds()
    Gets the value of thresholds or its default value.

getWeightCol()
    Gets the value of weightCol or its default value.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.
setFeaturesCol(value)
    Sets the value of featuresCol.

    New in version 3.0.0.

setLabelCol(value)
    Sets the value of labelCol.

    New in version 3.0.0.

setModelType(value)
    Sets the value of modelType.

    New in version 1.5.0.

setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, modelType="multinomial", thresholds=None, weightCol=None)
    Sets params for Naive Bayes.

    New in version 1.5.0.

setPredictionCol(value)
    Sets the value of predictionCol.

    New in version 3.0.0.

setProbabilityCol(value)
    Sets the value of probabilityCol.

    New in version 3.0.0.

setRawPredictionCol(value)
    Sets the value of rawPredictionCol.

    New in version 3.0.0.

setSmoothing(value)
    Sets the value of smoothing.

    New in version 1.5.0.

setThresholds(value)
    Sets the value of thresholds.

    New in version 3.0.0.

setWeightCol(value)
    Sets the value of weightCol.

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
modelType = Param(parent='undefined', name='modelType', doc='The model type which is a string (case-sensitive). Supported options: multinomial (default), bernoulli and gaussian.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
smoothing = Param(parent='undefined', name='smoothing', doc='The smoothing parameter, should be >= 0, default is 1.0')
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
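To make the p/t rule in the thresholds doc concrete, a plain-Python illustration using illustrative probabilities together with the thresholds from the setThresholds example above:

probabilities = [0.32, 0.68]              # p: predicted class probabilities (illustrative)
thresholds = [0.01, 10.0]                 # t: per-class thresholds, as in setThresholds above
scaled = [p / t for p, t in zip(probabilities, thresholds)]
prediction = scaled.index(max(scaled))
print(prediction)                         # 0: the large threshold on class 1 flips the decision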
NaiveBayesModel
class pyspark.ml.classification.NaiveBayesModel(java_model=None)
    Model fitted by NaiveBayes.

    New in version 1.5.0.
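A minimal usage sketch, assuming the `df` from the NaiveBayes example above:

from pyspark.ml.classification import NaiveBayes

model = NaiveBayes(modelType="multinomial").fit(df)
print(model.numClasses, model.numFeatures)
print(model.pi)      # log of class priors
print(model.theta)   # log of class conditional probabilities (numClasses x numFeatures)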
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getModelType(): Gets the value of modelType or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getProbabilityCol(): Gets the value of probabilityCol or its default value.
getRawPredictionCol(): Gets the value of rawPredictionCol or its default value.
getSmoothing(): Gets the value of smoothing or its default value.
getThresholds(): Gets the value of thresholds or its default value.
getWeightCol(): Gets the value of weightCol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value): Predict label for the given features.
predictProbability(value): Predict the probability of each class given the features.
predictRaw(value): Raw prediction for each possible label.
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setFeaturesCol(value): Sets the value of featuresCol.
setPredictionCol(value): Sets the value of predictionCol.
setProbabilityCol(value): Sets the value of probabilityCol.
setRawPredictionCol(value): Sets the value of rawPredictionCol.
setThresholds(value): Sets the value of thresholds.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
featuresCol
labelCol
modelType
numClasses: Number of classes (values which the label can take).
numFeatures: Returns the number of features the model was trained on.
params: Returns all params ordered by name.
pi: log of class priors.
predictionCol
probabilityCol
rawPredictionCol
sigma: variance of each feature.
smoothing
theta: log of class conditional probabilities.
thresholds
weightCol
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map
getFeaturesCol()
    Gets the value of featuresCol or its default value.

getLabelCol()
    Gets the value of labelCol or its default value.

getModelType()
    Gets the value of modelType or its default value.

    New in version 1.5.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getPredictionCol()
    Gets the value of predictionCol or its default value.

getProbabilityCol()
    Gets the value of probabilityCol or its default value.

getRawPredictionCol()
    Gets the value of rawPredictionCol or its default value.

getSmoothing()
    Gets the value of smoothing or its default value.

    New in version 1.5.0.

getThresholds()
    Gets the value of thresholds or its default value.

getWeightCol()
    Gets the value of weightCol or its default value.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
    Predict label for the given features.

    New in version 3.0.0.

predictProbability(value)
    Predict the probability of each class given the features.

    New in version 3.0.0.

predictRaw(value)
    Raw prediction for each possible label.

    New in version 3.0.0.
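For intuition only (this mirrors the multinomial scoring rule and is not a documented contract): the raw score can be reconstructed from pi and theta, and the probabilities are its normalized exponentials. A sketch assuming a fitted multinomial `model` as above and NumPy:

import numpy as np

x = np.array([1.0, 0.0])                               # a feature vector
raw = model.pi.toArray() + model.theta.toArray() @ x   # log prior + log likelihoods
prob = np.exp(raw - raw.max())
prob /= prob.sum()                                     # should match model.predictProbability(x)
print(raw.argmax())                                    # should match model.predict(x)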
classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setFeaturesCol(value)
    Sets the value of featuresCol.

    New in version 3.0.0.

setPredictionCol(value)
    Sets the value of predictionCol.

    New in version 3.0.0.

setProbabilityCol(value)
    Sets the value of probabilityCol.

    New in version 3.0.0.

setRawPredictionCol(value)
    Sets the value of rawPredictionCol.

    New in version 3.0.0.

setThresholds(value)
    Sets the value of thresholds.

    New in version 3.0.0.
transform(dataset, params=None)
    Transforms the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset
        params [dict, optional] an optional param map that overrides embedded params.

    Returns
        pyspark.sql.DataFrame transformed dataset

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
modelType = Param(parent='undefined', name='modelType', doc='The model type which is a string (case-sensitive). Supported options: multinomial (default), bernoulli and gaussian.')
numClasses
    Number of classes (values which the label can take).

    New in version 2.1.0.

numFeatures
    Returns the number of features the model was trained on. If unknown, returns -1.

    New in version 2.1.0.

params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

pi
    log of class priors.

    New in version 2.0.0.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
sigma
    variance of each feature.

    New in version 3.0.0.
smoothing = Param(parent='undefined', name='smoothing', doc='The smoothing parameter, should be >= 0, default is 1.0')
theta
    log of class conditional probabilities.

    New in version 2.0.0.
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
MultilayerPerceptronClassifier
class pyspark.ml.classification.MultilayerPerceptronClassifier(*args, **kwargs)
    Classifier trainer based on the Multilayer Perceptron. Each layer has a sigmoid activation function; the output layer has softmax. The number of inputs has to be equal to the size of feature vectors. The number of outputs has to be equal to the total number of labels.

    New in version 1.6.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (0.0, Vectors.dense([0.0, 0.0])),
...     (1.0, Vectors.dense([0.0, 1.0])),
...     (1.0, Vectors.dense([1.0, 0.0])),
...     (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"])
>>> mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], seed=123)
>>> mlp.setMaxIter(100)
MultilayerPerceptronClassifier...
>>> mlp.getMaxIter()
100
>>> mlp.getBlockSize()
128
>>> mlp.setBlockSize(1)
MultilayerPerceptronClassifier...
>>> mlp.getBlockSize()
1
>>> model = mlp.fit(df)
>>> model.setFeaturesCol("features")
MultilayerPerceptronClassificationModel...
>>> model.getMaxIter()
100
>>> model.getLayers()
[2, 2, 2]
>>> model.weights.size
12
>>> testDF = spark.createDataFrame([
...     (Vectors.dense([1.0, 0.0]),),
...     (Vectors.dense([0.0, 0.0]),)], ["features"])
>>> model.predict(testDF.head().features)
1.0
>>> model.predictRaw(testDF.head().features)
DenseVector([-16.208, 16.344])
>>> model.predictProbability(testDF.head().features)
DenseVector([0.0, 1.0])
>>> model.transform(testDF).select("features", "prediction").show()
+---------+----------+
| features|prediction|
+---------+----------+
|[1.0,0.0]|       1.0|
|[0.0,0.0]|       0.0|
+---------+----------+
...
>>> mlp_path = temp_path + "/mlp"
>>> mlp.save(mlp_path)
>>> mlp2 = MultilayerPerceptronClassifier.load(mlp_path)
>>> mlp2.getBlockSize()
1
>>> model_path = temp_path + "/mlp_model"
>>> model.save(model_path)
>>> model2 = MultilayerPerceptronClassificationModel.load(model_path)
>>> model.getLayers() == model2.getLayers()
True
>>> model.weights == model2.weights
True
>>> model.transform(testDF).take(1) == model2.transform(testDF).take(1)
True
>>> mlp2 = mlp2.setInitialWeights(list(range(0, 12)))
>>> model3 = mlp2.fit(df)
>>> model3.weights != model2.weights
True
>>> model3.getLayers() == model.getLayers()
True
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getBlockSize(): Gets the value of blockSize or its default value.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getInitialWeights(): Gets the value of initialWeights or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getLayers(): Gets the value of layers or its default value.
getMaxIter(): Gets the value of maxIter or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getProbabilityCol(): Gets the value of probabilityCol or its default value.
getRawPredictionCol(): Gets the value of rawPredictionCol or its default value.
getSeed(): Gets the value of seed or its default value.
getSolver(): Gets the value of solver or its default value.
getStepSize(): Gets the value of stepSize or its default value.
getThresholds(): Gets the value of thresholds or its default value.
getTol(): Gets the value of tol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setBlockSize(value): Sets the value of blockSize.
setFeaturesCol(value): Sets the value of featuresCol.
setInitialWeights(value): Sets the value of initialWeights.
setLabelCol(value): Sets the value of labelCol.
setLayers(value): Sets the value of layers.
setMaxIter(value): Sets the value of maxIter.
setParams(*args, **kwargs): Sets params for MultilayerPerceptronClassifier.
setPredictionCol(value): Sets the value of predictionCol.
setProbabilityCol(value): Sets the value of probabilityCol.
setRawPredictionCol(value): Sets the value of rawPredictionCol.
setSeed(value): Sets the value of seed.
setSolver(value): Sets the value of solver.
setStepSize(value): Sets the value of stepSize.
setThresholds(value): Sets the value of thresholds.
setTol(value): Sets the value of tol.
write(): Returns an MLWriter instance for this ML instance.
Attributes
blockSize
featuresCol
initialWeights
labelCol
layers
maxIter
params: Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
solver
stepSize
thresholds
tol
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

    Returns
        Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.

    New in version 2.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        paramMaps [collections.abc.Sequence] A Sequence of param maps.

    Returns
        _FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getBlockSize()
    Gets the value of blockSize or its default value.

getFeaturesCol()
    Gets the value of featuresCol or its default value.

getInitialWeights()
    Gets the value of initialWeights or its default value.

    New in version 2.0.0.

getLabelCol()
    Gets the value of labelCol or its default value.

getLayers()
    Gets the value of layers or its default value.

    New in version 1.6.0.

getMaxIter()
    Gets the value of maxIter or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getPredictionCol()
    Gets the value of predictionCol or its default value.

getProbabilityCol()
    Gets the value of probabilityCol or its default value.

getRawPredictionCol()
    Gets the value of rawPredictionCol or its default value.

getSeed()
    Gets the value of seed or its default value.

getSolver()
    Gets the value of solver or its default value.

getStepSize()
    Gets the value of stepSize or its default value.

getThresholds()
    Gets the value of thresholds or its default value.

getTol()
    Gets the value of tol or its default value.
hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.
setBlockSize(value)
    Sets the value of blockSize.

    New in version 1.6.0.

setFeaturesCol(value)
    Sets the value of featuresCol.

    New in version 3.0.0.

setInitialWeights(value)
    Sets the value of initialWeights.

    New in version 2.0.0.

setLabelCol(value)
    Sets the value of labelCol.

    New in version 3.0.0.

setLayers(value)
    Sets the value of layers.

    New in version 1.6.0.

setMaxIter(value)
    Sets the value of maxIter.

setParams(*args, **kwargs)
    setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, solver="l-bfgs", initialWeights=None, probabilityCol="probability", rawPredictionCol="rawPrediction"): Sets params for MultilayerPerceptronClassifier.

    New in version 1.6.0.

setPredictionCol(value)
    Sets the value of predictionCol.

    New in version 3.0.0.

setProbabilityCol(value)
    Sets the value of probabilityCol.

    New in version 3.0.0.

setRawPredictionCol(value)
    Sets the value of rawPredictionCol.

    New in version 3.0.0.

setSeed(value)
    Sets the value of seed.

setSolver(value)
    Sets the value of solver.

setStepSize(value)
    Sets the value of stepSize.

    New in version 2.0.0.

setThresholds(value)
    Sets the value of thresholds.

    New in version 3.0.0.

setTol(value)
    Sets the value of tol.

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
blockSize = Param(parent='undefined', name='blockSize', doc='block size for stacking input data in matrices. Data is stacked within partitions. If block size is more than remaining data in a partition then it is adjusted to the size of this data.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
initialWeights = Param(parent='undefined', name='initialWeights', doc='The initial weights of the model.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
layers = Param(parent='undefined', name='layers', doc='Sizes of layers from input layer to output layer E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 neurons and output layer of 10 neurons.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: l-bfgs, gd.')
stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
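As a worked check of the layers doc above: in the classifier example, layers=[2, 2, 2] yields model.weights.size == 12. Assuming each layer transition carries in*out weights plus out biases (an illustrative assumption about the layout), the count follows directly:

layers = [2, 2, 2]
n_weights = sum(n_in * n_out + n_out for n_in, n_out in zip(layers, layers[1:]))
print(n_weights)   # (2*2 + 2) + (2*2 + 2) = 12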
MultilayerPerceptronClassificationModel
class pyspark.ml.classification.MultilayerPerceptronClassificationModel(java_model=None)
    Model fitted by MultilayerPerceptronClassifier.

    New in version 1.6.0.
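A minimal sketch of the training summary and hold-out evaluation added in 3.1.0, assuming `model` and the labeled `df` from the example above:

summary = model.summary()           # training summary; throws if none is available
print(summary.accuracy)
print(summary.objectiveHistory)     # one loss value per iteration, plus the initial state
test_summary = model.evaluate(df)   # any DataFrame with the label column works here
print(test_summary.weightedFMeasure())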
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset): Evaluates the model on a test dataset.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBlockSize(): Gets the value of blockSize or its default value.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getInitialWeights(): Gets the value of initialWeights or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getLayers(): Gets the value of layers or its default value.
getMaxIter(): Gets the value of maxIter or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getProbabilityCol(): Gets the value of probabilityCol or its default value.
getRawPredictionCol(): Gets the value of rawPredictionCol or its default value.
getSeed(): Gets the value of seed or its default value.
getSolver(): Gets the value of solver or its default value.
getStepSize(): Gets the value of stepSize or its default value.
getThresholds(): Gets the value of thresholds or its default value.
getTol(): Gets the value of tol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value): Predict label for the given features.
predictProbability(value): Predict the probability of each class given the features.
predictRaw(value): Raw prediction for each possible label.
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setFeaturesCol(value): Sets the value of featuresCol.
setPredictionCol(value): Sets the value of predictionCol.
setProbabilityCol(value): Sets the value of probabilityCol.
setRawPredictionCol(value): Sets the value of rawPredictionCol.
setThresholds(value): Sets the value of thresholds.
summary(): Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of the model trained on the training set.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
blockSize
featuresCol
hasSummary: Indicates whether a training summary exists for this model instance.
initialWeights
labelCol
layers
maxIter
numClasses: Number of classes (values which the label can take).
numFeatures: Returns the number of features the model was trained on.
params: Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
solver
stepSize
thresholds
tol
weights: the weights of layers.
3.3. ML 689
pyspark Documentation, Release master
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

evaluate(dataset)
    Evaluates the model on a test dataset.

    New in version 3.1.0.

    Parameters
        dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on.

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map
getBlockSize()
    Gets the value of blockSize or its default value.

getFeaturesCol()
    Gets the value of featuresCol or its default value.

getInitialWeights()
    Gets the value of initialWeights or its default value.

    New in version 2.0.0.

getLabelCol()
    Gets the value of labelCol or its default value.

getLayers()
    Gets the value of layers or its default value.

    New in version 1.6.0.

getMaxIter()
    Gets the value of maxIter or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getPredictionCol()
    Gets the value of predictionCol or its default value.

getProbabilityCol()
    Gets the value of probabilityCol or its default value.

getRawPredictionCol()
    Gets the value of rawPredictionCol or its default value.

getSeed()
    Gets the value of seed or its default value.

getSolver()
    Gets the value of solver or its default value.

getStepSize()
    Gets the value of stepSize or its default value.

getThresholds()
    Gets the value of thresholds or its default value.

getTol()
    Gets the value of tol or its default value.
hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
    Predict label for the given features.

    New in version 3.0.0.

predictProbability(value)
    Predict the probability of each class given the features.

    New in version 3.0.0.

predictRaw(value)
    Raw prediction for each possible label.

    New in version 3.0.0.
classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setFeaturesCol(value)
    Sets the value of featuresCol.

    New in version 3.0.0.

setPredictionCol(value)
    Sets the value of predictionCol.

    New in version 3.0.0.

setProbabilityCol(value)
    Sets the value of probabilityCol.

    New in version 3.0.0.

setRawPredictionCol(value)
    Sets the value of rawPredictionCol.

    New in version 3.0.0.

setThresholds(value)
    Sets the value of thresholds.

    New in version 3.0.0.

summary()
    Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of the model trained on the training set. An exception is thrown if trainingSummary is None.

    New in version 3.1.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset
        params [dict, optional] an optional param map that overrides embedded params.

    Returns
        pyspark.sql.DataFrame transformed dataset

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
blockSize = Param(parent='undefined', name='blockSize', doc='block size for stacking input data in matrices. Data is stacked within partitions. If block size is more than remaining data in a partition then it is adjusted to the size of this data.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
hasSummary
    Indicates whether a training summary exists for this model instance.

    New in version 2.1.0.
initialWeights = Param(parent='undefined', name='initialWeights', doc='The initial weights of the model.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
layers = Param(parent='undefined', name='layers', doc='Sizes of layers from input layer to output layer E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 neurons and output layer of 10 neurons.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
numClasses
    Number of classes (values which the label can take).

    New in version 2.1.0.

numFeatures
    Returns the number of features the model was trained on. If unknown, returns -1.

    New in version 2.1.0.

params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: l-bfgs, gd.')
stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
weights
    the weights of layers.

    New in version 2.0.0.
MultilayerPerceptronClassificationSummary
class pyspark.ml.classification.MultilayerPerceptronClassificationSummary(java_obj=None)
    Abstraction for MultilayerPerceptronClassifier results for a given model.

    New in version 3.1.0.
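A minimal sketch of reading metrics off such a summary, assuming a `summary` obtained from model.summary() or model.evaluate(...) as sketched above:

print(summary.labels)                     # label values in ascending order
print(summary.precisionByLabel)           # one entry per label
print(summary.recallByLabel)
print(summary.fMeasureByLabel(beta=1.0))  # per-label f-measure
print(summary.weightedFMeasure())         # single weighted aggregate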
Methods
fMeasureByLabel([beta]): Returns f-measure for each label (category).
weightedFMeasure([beta]): Returns weighted averaged f-measure.
Attributes
accuracy: Returns accuracy.
falsePositiveRateByLabel: Returns false positive rate for each label (category).
labelCol: Field in "predictions" which gives the true label of each instance.
labels: Returns the sequence of labels in ascending order.
precisionByLabel: Returns precision for each label (category).
predictionCol: Field in "predictions" which gives the prediction of each class.
predictions: Dataframe outputted by the model's transform method.
recallByLabel: Returns recall for each label (category).
truePositiveRateByLabel: Returns true positive rate for each label (category).
weightCol: Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate: Returns weighted false positive rate.
weightedPrecision: Returns weighted averaged precision.
weightedRecall: Returns weighted averaged recall.
weightedTruePositiveRate: Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)
    Returns f-measure for each label (category).

    New in version 3.1.0.

weightedFMeasure(beta=1.0)
    Returns weighted averaged f-measure.

    New in version 3.1.0.
Attributes Documentation
accuracy
    Returns accuracy (equal to the number of correctly classified instances divided by the total number of instances).

    New in version 3.1.0.

falsePositiveRateByLabel
    Returns false positive rate for each label (category).

    New in version 3.1.0.

labelCol
    Field in "predictions" which gives the true label of each instance.

    New in version 3.1.0.
labels
    Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

    New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.
precisionByLabel
    Returns precision for each label (category).

    New in version 3.1.0.

predictionCol
    Field in "predictions" which gives the prediction of each class.

    New in version 3.1.0.

predictions
    Dataframe outputted by the model's transform method.

    New in version 3.1.0.

recallByLabel
    Returns recall for each label (category).

    New in version 3.1.0.

truePositiveRateByLabel
    Returns true positive rate for each label (category).

    New in version 3.1.0.

weightCol
    Field in "predictions" which gives the weight of each instance as a vector.

    New in version 3.1.0.

weightedFalsePositiveRate
    Returns weighted false positive rate.

    New in version 3.1.0.

weightedPrecision
    Returns weighted averaged precision.

    New in version 3.1.0.

weightedRecall
    Returns weighted averaged recall (equal to precision, recall and f-measure).

    New in version 3.1.0.

weightedTruePositiveRate
    Returns weighted true positive rate (equal to precision, recall and f-measure).

    New in version 3.1.0.
MultilayerPerceptronClassificationTrainingSummary
class pyspark.ml.classification.MultilayerPerceptronClassificationTrainingSummary(java_obj=None)
    Abstraction for MultilayerPerceptronClassifier training results.

    New in version 3.1.0.
Methods
fMeasureByLabel([beta]): Returns f-measure for each label (category).
weightedFMeasure([beta]): Returns weighted averaged f-measure.
Attributes
accuracy: Returns accuracy.
falsePositiveRateByLabel: Returns false positive rate for each label (category).
labelCol: Field in "predictions" which gives the true label of each instance.
labels: Returns the sequence of labels in ascending order.
objectiveHistory: Objective function (scaled loss + regularization) at each iteration.
precisionByLabel: Returns precision for each label (category).
predictionCol: Field in "predictions" which gives the prediction of each class.
predictions: Dataframe outputted by the model's transform method.
recallByLabel: Returns recall for each label (category).
totalIterations: Number of training iterations until termination.
truePositiveRateByLabel: Returns true positive rate for each label (category).
weightCol: Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate: Returns weighted false positive rate.
weightedPrecision: Returns weighted averaged precision.
weightedRecall: Returns weighted averaged recall.
weightedTruePositiveRate: Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)
    Returns f-measure for each label (category).

    New in version 3.1.0.

weightedFMeasure(beta=1.0)
    Returns weighted averaged f-measure.

    New in version 3.1.0.
Attributes Documentation
accuracy
    Returns accuracy (equal to the number of correctly classified instances divided by the total number of instances).

    New in version 3.1.0.

falsePositiveRateByLabel
    Returns false positive rate for each label (category).

    New in version 3.1.0.

labelCol
    Field in "predictions" which gives the true label of each instance.

    New in version 3.1.0.

labels
    Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

    New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.
objectiveHistory
    Objective function (scaled loss + regularization) at each iteration. It contains one more element, the initial state, than the number of iterations.

    New in version 3.1.0.

precisionByLabel
    Returns precision for each label (category).

    New in version 3.1.0.

predictionCol
    Field in "predictions" which gives the prediction of each class.

    New in version 3.1.0.

predictions
    Dataframe outputted by the model's transform method.

    New in version 3.1.0.

recallByLabel
    Returns recall for each label (category).

    New in version 3.1.0.

totalIterations
    Number of training iterations until termination.

    New in version 3.1.0.
truePositiveRateByLabel
    Returns true positive rate for each label (category).

    New in version 3.1.0.

weightCol
    Field in "predictions" which gives the weight of each instance as a vector.

    New in version 3.1.0.

weightedFalsePositiveRate
    Returns weighted false positive rate.

    New in version 3.1.0.

weightedPrecision
    Returns weighted averaged precision.

    New in version 3.1.0.

weightedRecall
    Returns weighted averaged recall (equal to precision, recall and f-measure).

    New in version 3.1.0.

weightedTruePositiveRate
    Returns weighted true positive rate (equal to precision, recall and f-measure).

    New in version 3.1.0.
OneVsRest
class pyspark.ml.classification.OneVsRest(*args, **kwargs)
    Reduction of Multiclass Classification to Binary Classification. Performs reduction using the one-against-all strategy. For a multiclass classification with k classes, train k models (one per class). Each example is scored against all k models, and the model with the highest score is picked to label the example.

    New in version 2.0.0.
Examples
>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
>>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
>>> df = spark.read.format("libsvm").load(data_path)
>>> lr = LogisticRegression(regParam=0.01)
>>> ovr = OneVsRest(classifier=lr)
>>> ovr.getRawPredictionCol()
'rawPrediction'
>>> ovr.setPredictionCol("newPrediction")
OneVsRest...
>>> model = ovr.fit(df)
>>> model.models[0].coefficients
DenseVector([0.5..., -1.0..., 3.4..., 4.2...])
>>> model.models[1].coefficients
DenseVector([-2.1..., 3.1..., -2.6..., -2.3...])
>>> model.models[2].coefficients
DenseVector([0.3..., -3.4..., 1.0..., -1.1...])
>>> [x.intercept for x in model.models]
[-2.7..., -2.5..., -1.3...]
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF()
>>> model.transform(test0).head().newPrediction
0.0
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF()
>>> model.transform(test1).head().newPrediction
2.0
>>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF()
>>> model.transform(test2).head().newPrediction
0.0
>>> model_path = temp_path + "/ovr_model"
>>> model.save(model_path)
>>> model2 = OneVsRestModel.load(model_path)
>>> model2.transform(test0).head().newPrediction
0.0
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> model.transform(test2).columns
['features', 'rawPrediction', 'newPrediction']
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with a randomly generated uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getClassifier(): Gets the value of classifier or its default value.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParallelism(): Gets the value of parallelism or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getRawPredictionCol(): Gets the value of rawPredictionCol or its default value.
getWeightCol(): Gets the value of weightCol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value): Sets a parameter in the embedded param map.
setClassifier(value): Sets the value of classifier.
setFeaturesCol(value): Sets the value of featuresCol.
setLabelCol(value): Sets the value of labelCol.
setParallelism(value): Sets the value of parallelism.
setParams(*args, **kwargs): Sets params for OneVsRest.
setPredictionCol(value): Sets the value of predictionCol.
setRawPredictionCol(value): Sets the value of rawPredictionCol.
setWeightCol(value): Sets the value of weightCol.
write(): Returns an MLWriter instance for this ML instance.
Attributes
classifier
featuresCol
labelCol
parallelism
params: Returns all params ordered by name.
predictionCol
rawPredictionCol
weightCol
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with a randomly generated uid and some extra params. This creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over.

    New in version 2.0.0.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        OneVsRest Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

    Returns
        Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.

    New in version 2.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        paramMaps [collections.abc.Sequence] A Sequence of param maps.

    Returns
        _FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getClassifier()
    Gets the value of classifier or its default value.

    New in version 2.0.0.

getFeaturesCol()
    Gets the value of featuresCol or its default value.

getLabelCol()
    Gets the value of labelCol or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParallelism()
    Gets the value of parallelism or its default value.

getParam(paramName)
    Gets a param by its name.

getPredictionCol()
    Gets the value of predictionCol or its default value.

getRawPredictionCol()
    Gets the value of rawPredictionCol or its default value.

getWeightCol()
    Gets the value of weightCol or its default value.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
    Sets a parameter in the embedded param map.

setClassifier(value)
    Sets the value of classifier.

    New in version 2.0.0.
setFeaturesCol(value)
    Sets the value of featuresCol.

setLabelCol(value)
    Sets the value of labelCol.

setParallelism(value)
    Sets the value of parallelism.

setParams(*args, **kwargs)
    setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1): Sets params for OneVsRest.

    New in version 2.0.0.

setPredictionCol(value)
    Sets the value of predictionCol.

setRawPredictionCol(value)
    Sets the value of rawPredictionCol.

setWeightCol(value)
    Sets the value of weightCol.

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
classifier = Param(parent='undefined', name='classifier', doc='base binary classifier')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
parallelism = Param(parent='undefined', name='parallelism', doc='the number of threads to use when running parallel algorithms (>= 1).')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
OneVsRestModel
class pyspark.ml.classification.OneVsRestModel(models)
Model fitted by OneVsRest. This stores the models resulting from training k binary classifiers: one for each class. Each example is scored against all k models, and the model with the highest score is picked to label the example.
New in version 2.0.0.
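A short usage sketch (assumed, not from the upstream docs; train_df and test_df are hypothetical DataFrames with multiclass labels):

>>> from pyspark.ml.classification import LogisticRegression, OneVsRest
>>> ovr = OneVsRest(classifier=LogisticRegression())
>>> model = ovr.fit(train_df)                      # train_df: hypothetical DataFrame
>>> len(model.models)                              # one binary classifier per class
>>> model.transform(test_df).select("prediction").show()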
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with a randomly generated uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getClassifier() Gets the value of classifier or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getRawPredictionCol() Gets the value of rawPredictionCol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setPredictionCol(value) Sets the value of predictionCol.
setRawPredictionCol(value) Sets the value of rawPredictionCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
classifier
featuresCol
labelCol
params Returns all params ordered by name.
predictionCol
rawPredictionCol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with a randomly generated uid and some extra params. This creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over.
New in version 2.0.0.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
OneVsRestModel Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getClassifier()
Gets the value of classifier or its default value.

New in version 2.0.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getLabelCol()
Gets the value of labelCol or its default value.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getRawPredictionCol()
Gets the value of rawPredictionCol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

setPredictionCol(value)
Sets the value of predictionCol.

setRawPredictionCol(value)
Sets the value of rawPredictionCol.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
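As an illustrative sketch (assumed, not from the upstream docs; test_df is a hypothetical DataFrame), the params argument can override an embedded param for a single call:

>>> out = model.transform(test_df, {model.predictionCol: "ovrPrediction"})
>>> "ovrPrediction" in out.columns
True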
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
classifier = Param(parent='undefined', name='classifier', doc='base binary classifier')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
FMClassifier
class pyspark.ml.classification.FMClassifier(*args, **kwargs)
Factorization Machines learning algorithm for classification.
Solver supports:
• gd (normal mini-batch gradient descent)
• adamW (default)
New in version 3.0.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.classification import FMClassifier
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> fm = FMClassifier(factorSize=2)
>>> fm.setSeed(11)
FMClassifier...
>>> model = fm.fit(df)
>>> model.getMaxIter()
100
>>> test0 = spark.createDataFrame([
...     (Vectors.dense(-1.0),),
...     (Vectors.dense(0.5),),
...     (Vectors.dense(1.0),),
...     (Vectors.dense(2.0),)], ["features"])
>>> model.predictRaw(test0.head().features)
DenseVector([22.13..., -22.13...])
>>> model.predictProbability(test0.head().features)
DenseVector([1.0, 0.0])
>>> model.transform(test0).select("features", "probability").show(10, False)
+--------+------------------------------------------+
|features|probability                               |
+--------+------------------------------------------+
|[-1.0]  |[0.9999999997574736,2.425264676902229E-10]|
|[0.5]   |[0.47627851732981163,0.5237214826701884]  |
|[1.0]   |[5.491554426243495E-4,0.9994508445573757] |
|[2.0]   |[2.005766663870645E-10,0.9999999997994233]|
+--------+------------------------------------------+
...
>>> model.intercept
-7.316665276826291
>>> model.linear
DenseVector([14.8232])
>>> model.factors
DenseMatrix(1, 2, [0.0163, -0.0051], 1)
>>> model_path = temp_path + "/fm_model"
>>> model.save(model_path)
>>> model2 = FMClassificationModel.load(model_path)
>>> model2.intercept
-7.316665276826291
>>> model2.linear
DenseVector([14.8232])
>>> model2.factors
DenseMatrix(1, 2, [0.0163, -0.0051], 1)
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getFactorSize() Gets the value of factorSize or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFitIntercept() Gets the value of fitIntercept or its default value.
getFitLinear() Gets the value of fitLinear or its default value.
getInitStd() Gets the value of initStd or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getMiniBatchFraction() Gets the value of miniBatchFraction or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getProbabilityCol() Gets the value of probabilityCol or its default value.
getRawPredictionCol() Gets the value of rawPredictionCol or its default value.
getRegParam() Gets the value of regParam or its default value.
getSeed() Gets the value of seed or its default value.
getSolver() Gets the value of solver or its default value.
getStepSize() Gets the value of stepSize or its default value.
getThresholds() Gets the value of thresholds or its default value.
getTol() Gets the value of tol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setFactorSize(value) Sets the value of factorSize.
setFeaturesCol(value) Sets the value of featuresCol.
setFitIntercept(value) Sets the value of fitIntercept.
setFitLinear(value) Sets the value of fitLinear.
setInitStd(value) Sets the value of initStd.
setLabelCol(value) Sets the value of labelCol.
setMaxIter(value) Sets the value of maxIter.
setMiniBatchFraction(value) Sets the value of miniBatchFraction.
setParams(self, \*[, featuresCol, labelCol, ...]) Sets Params for FMClassifier.
setPredictionCol(value) Sets the value of predictionCol.
setProbabilityCol(value) Sets the value of probabilityCol.
setRawPredictionCol(value) Sets the value of rawPredictionCol.
setRegParam(value) Sets the value of regParam.
setSeed(value) Sets the value of seed.
setSolver(value) Sets the value of solver.
setStepSize(value) Sets the value of stepSize.
setThresholds(value) Sets the value of thresholds.
setTol(value) Sets the value of tol.
write() Returns an MLWriter instance for this ML instance.
Attributes
factorSize
featuresCol
fitIntercept
fitLinear
initStd
labelCol
maxIter
miniBatchFraction
params Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
regParam
seed
solver
stepSize
thresholds
tol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getFactorSize()
Gets the value of factorSize or its default value.

New in version 3.0.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFitIntercept()
Gets the value of fitIntercept or its default value.

getFitLinear()
Gets the value of fitLinear or its default value.

New in version 3.0.0.

getInitStd()
Gets the value of initStd or its default value.

New in version 3.0.0.

getLabelCol()
Gets the value of labelCol or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getMiniBatchFraction()
Gets the value of miniBatchFraction or its default value.
New in version 3.0.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getRawPredictionCol()
Gets the value of rawPredictionCol or its default value.

getRegParam()
Gets the value of regParam or its default value.

getSeed()
Gets the value of seed or its default value.

getSolver()
Gets the value of solver or its default value.

getStepSize()
Gets the value of stepSize or its default value.

getThresholds()
Gets the value of thresholds or its default value.

getTol()
Gets the value of tol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setFactorSize(value)
Sets the value of factorSize.
New in version 3.0.0.
setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setFitIntercept(value)
Sets the value of fitIntercept.

New in version 3.0.0.

setFitLinear(value)
Sets the value of fitLinear.

New in version 3.0.0.

setInitStd(value)
Sets the value of initStd.

New in version 3.0.0.

setLabelCol(value)
Sets the value of labelCol.

New in version 3.0.0.

setMaxIter(value)
Sets the value of maxIter.

New in version 3.0.0.

setMiniBatchFraction(value)
Sets the value of miniBatchFraction.
New in version 3.0.0.
setParams(self, \*, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, tol=1e-6, solver="adamW", thresholds=None, seed=None)
Sets Params for FMClassifier.
New in version 3.0.0.
setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)
Sets the value of rawPredictionCol.

New in version 3.0.0.

setRegParam(value)
Sets the value of regParam.

New in version 3.0.0.
setSeed(value)
Sets the value of seed.

New in version 3.0.0.

setSolver(value)
Sets the value of solver.

New in version 3.0.0.

setStepSize(value)
Sets the value of stepSize.

New in version 3.0.0.

setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.

setTol(value)
Sets the value of tol.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
factorSize = Param(parent='undefined', name='factorSize', doc='Dimensionality of the factor vectors, which are used to get pairwise interactions between variables')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
fitLinear = Param(parent='undefined', name='fitLinear', doc='whether to fit linear term (aka 1-way term)')
initStd = Param(parent='undefined', name='initStd', doc='standard deviation of initial coefficients')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
miniBatchFraction = Param(parent='undefined', name='miniBatchFraction', doc='fraction of the input data set that should be used for one iteration of gradient descent')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')
seed = Param(parent='undefined', name='seed', doc='random seed.')
solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: gd, adamW. (Default adamW)')
stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
FMClassificationModel
class pyspark.ml.classification.FMClassificationModel(java_model=None)
Model fitted by FMClassifier.
New in version 3.0.0.
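A brief usage sketch (assumed, not from the upstream docs; train_df is a hypothetical DataFrame with binary labels):

>>> from pyspark.ml.classification import FMClassifier
>>> from pyspark.ml.linalg import Vectors
>>> model = FMClassifier(factorSize=2).fit(train_df)   # train_df: hypothetical DataFrame
>>> model.predict(Vectors.dense([1.0]))                # predicted label for one feature vector
>>> model.intercept, model.linear, model.factors       # learned terms of the factorization machine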
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset) Evaluates the model on a test dataset.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFactorSize() Gets the value of factorSize or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFitIntercept() Gets the value of fitIntercept or its default value.
getFitLinear() Gets the value of fitLinear or its default value.
getInitStd() Gets the value of initStd or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getMiniBatchFraction() Gets the value of miniBatchFraction or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getProbabilityCol() Gets the value of probabilityCol or its default value.
getRawPredictionCol() Gets the value of rawPredictionCol or its default value.
getRegParam() Gets the value of regParam or its default value.
getSeed() Gets the value of seed or its default value.
getSolver() Gets the value of solver or its default value.
getStepSize() Gets the value of stepSize or its default value.
getThresholds() Gets the value of thresholds or its default value.
getTol() Gets the value of tol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) Predict label for the given features.
predictProbability(value) Predict the probability of each class given the features.
predictRaw(value) Raw prediction for each possible label.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setPredictionCol(value) Sets the value of predictionCol.
setProbabilityCol(value) Sets the value of probabilityCol.
setRawPredictionCol(value) Sets the value of rawPredictionCol.
setThresholds(value) Sets the value of thresholds.
summary() Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on the training set.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
factorSize
factors Model factor term.
featuresCol
fitIntercept
fitLinear
hasSummary Indicates whether a training summary exists for this model instance.
initStd
intercept Model intercept.
labelCol
linear Model linear term.
maxIter
miniBatchFraction
numClasses Number of classes (values which the label can take).
numFeatures Returns the number of features the model was trained on.
params Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
regParam
seed
solver
stepSize
thresholds
tol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
evaluate(dataset)
Evaluates the model on a test dataset.
New in version 3.1.0.
Parameters
dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on.
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getFactorSize()
Gets the value of factorSize or its default value.

New in version 3.0.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.
getFitIntercept()
Gets the value of fitIntercept or its default value.

getFitLinear()
Gets the value of fitLinear or its default value.

New in version 3.0.0.

getInitStd()
Gets the value of initStd or its default value.

New in version 3.0.0.

getLabelCol()
Gets the value of labelCol or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getMiniBatchFraction()
Gets the value of miniBatchFraction or its default value.

New in version 3.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getRawPredictionCol()
Gets the value of rawPredictionCol or its default value.

getRegParam()
Gets the value of regParam or its default value.

getSeed()
Gets the value of seed or its default value.

getSolver()
Gets the value of solver or its default value.

getStepSize()
Gets the value of stepSize or its default value.

getThresholds()
Gets the value of thresholds or its default value.

getTol()
Gets the value of tol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
Predict label for the given features.

New in version 3.0.0.

predictProbability(value)
Predict the probability of each class given the features.

New in version 3.0.0.

predictRaw(value)
Raw prediction for each possible label.

New in version 3.0.0.

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)
Sets the value of rawPredictionCol.

New in version 3.0.0.

setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.

summary()
Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on the training set. An exception is thrown if trainingSummary is None.
New in version 3.1.0.
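For instance, a sketch (assumed; model is a freshly trained FMClassificationModel, so a training summary exists, and summary is documented here as a method):

>>> s = model.summary()
>>> s.accuracy          # metrics computed on the training set
>>> s.totalIterations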
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
factorSize = Param(parent='undefined', name='factorSize', doc='Dimensionality of the factor vectors, which are used to get pairwise interactions between variables')
factors
Model factor term.
New in version 3.0.0.
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
fitLinear = Param(parent='undefined', name='fitLinear', doc='whether to fit linear term (aka 1-way term)')
hasSummary
Indicates whether a training summary exists for this model instance.
New in version 2.1.0.
initStd = Param(parent='undefined', name='initStd', doc='standard deviation of initial coefficients')
intercept
Model intercept.
New in version 3.0.0.
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
linear
Model linear term.
New in version 3.0.0.
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
miniBatchFraction = Param(parent='undefined', name='miniBatchFraction', doc='fraction of the input data set that should be used for one iteration of gradient descent')
numClasses
Number of classes (values which the label can take).
New in version 2.1.0.
numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')
seed = Param(parent='undefined', name='seed', doc='random seed.')
solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: gd, adamW. (Default adamW)')
stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')
thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
FMClassificationSummary
class pyspark.ml.classification.FMClassificationSummary(java_obj=None)
Abstraction for FMClassifier results for a given model.
New in version 3.1.0.
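A summary like this is typically obtained from FMClassificationModel.evaluate, sketched here under the assumption that model is a fitted FMClassificationModel and test_df is a hypothetical labeled DataFrame:

>>> test_summary = model.evaluate(test_df)   # test_df: hypothetical DataFrame
>>> test_summary.areaUnderROC
>>> test_summary.weightedPrecision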
Methods
fMeasureByLabel([beta]) Returns f-measure for each label (category).
weightedFMeasure([beta]) Returns weighted averaged f-measure.
Attributes
accuracy Returns accuracy.
areaUnderROC Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel Returns false positive rate for each label (category).
labelCol Field in "predictions" which gives the true label of each instance.
labels Returns the sequence of labels in ascending order.
pr Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel Returns precision for each label (category).
precisionByThreshold Returns a dataframe with two fields (threshold, precision) curve.
predictionCol Field in "predictions" which gives the prediction of each class.
predictions DataFrame produced by the model's transform method.
recallByLabel Returns recall for each label (category).
recallByThreshold Returns a dataframe with two fields (threshold, recall) curve.
roc Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol Field in "predictions" which gives the probability or raw prediction of each class as a vector.
truePositiveRateByLabel Returns true positive rate for each label (category).
weightCol Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate Returns weighted false positive rate.
weightedPrecision Returns weighted averaged precision.
weightedRecall Returns weighted averaged recall.
weightedTruePositiveRate Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)
Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)
Returns weighted averaged f-measure.
New in version 3.1.0.
Attributes Documentation
accuracy
Returns accuracy, i.e. the number of correctly classified instances divided by the total number of instances.

New in version 3.1.0.

areaUnderROC
Computes the area under the receiver operating characteristic (ROC) curve.

New in version 3.1.0.

fMeasureByThreshold
Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.

New in version 3.1.0.

falsePositiveRateByLabel
Returns false positive rate for each label (category).
New in version 3.1.0.
labelCol
Field in "predictions" which gives the true label of each instance.

New in version 3.1.0.

labels
Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.
New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.
pr
Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.

New in version 3.1.0.

precisionByLabel
Returns precision for each label (category).

New in version 3.1.0.

precisionByThreshold
Returns a dataframe with two fields (threshold, precision) curve. Every possible probability obtained in transforming the dataset is used as a threshold in calculating the precision.

New in version 3.1.0.

predictionCol
Field in "predictions" which gives the prediction of each class.

New in version 3.1.0.

predictions
DataFrame produced by the model's transform method.

New in version 3.1.0.

recallByLabel
Returns recall for each label (category).

New in version 3.1.0.

recallByThreshold
Returns a dataframe with two fields (threshold, recall) curve. Every possible probability obtained in transforming the dataset is used as a threshold in calculating the recall.

New in version 3.1.0.

roc
Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
New in version 3.1.0.
Notes
Wikipedia reference
scoreCol
Field in "predictions" which gives the probability or raw prediction of each class as a vector.

New in version 3.1.0.

truePositiveRateByLabel
Returns true positive rate for each label (category).

New in version 3.1.0.

weightCol
Field in "predictions" which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRate
Returns weighted false positive rate.

New in version 3.1.0.

weightedPrecision
Returns weighted averaged precision.

New in version 3.1.0.

weightedRecall
Returns weighted averaged recall. (equal to precision, recall and f-measure)

New in version 3.1.0.

weightedTruePositiveRate
Returns weighted true positive rate. (equal to precision, recall and f-measure)
New in version 3.1.0.
FMClassificationTrainingSummary
class pyspark.ml.classification.FMClassificationTrainingSummary(java_obj=None)
Abstraction for FMClassifier training results.
New in version 3.1.0.
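A training summary adds optimization diagnostics on top of the evaluation metrics, sketched here under the assumption that model is a freshly trained FMClassificationModel (and that summary is a method, as documented above):

>>> ts = model.summary()
>>> ts.objectiveHistory   # loss at each iteration, plus the initial state
>>> ts.totalIterations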
Methods
fMeasureByLabel([beta]) Returns f-measure for each label (category).
weightedFMeasure([beta]) Returns weighted averaged f-measure.
Attributes
accuracy Returns accuracy.
areaUnderROC Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel Returns false positive rate for each label (category).
labelCol Field in "predictions" which gives the true label of each instance.
labels Returns the sequence of labels in ascending order.
objectiveHistory Objective function (scaled loss + regularization) at each iteration.
pr Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel Returns precision for each label (category).
precisionByThreshold Returns a dataframe with two fields (threshold, precision) curve.
predictionCol Field in "predictions" which gives the prediction of each class.
predictions DataFrame produced by the model's transform method.
recallByLabel Returns recall for each label (category).
recallByThreshold Returns a dataframe with two fields (threshold, recall) curve.
roc Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol Field in "predictions" which gives the probability or raw prediction of each class as a vector.
totalIterations Number of training iterations until termination.
truePositiveRateByLabel Returns true positive rate for each label (category).
weightCol Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate Returns weighted false positive rate.
weightedPrecision Returns weighted averaged precision.
weightedRecall Returns weighted averaged recall.
weightedTruePositiveRate Returns weighted true positive rate.
Methods Documentation
fMeasureByLabel(beta=1.0)
Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)
Returns weighted averaged f-measure.
New in version 3.1.0.
Attributes Documentation
accuracy
Returns accuracy, i.e. the number of correctly classified instances divided by the total number of instances.

New in version 3.1.0.

areaUnderROC
Computes the area under the receiver operating characteristic (ROC) curve.

New in version 3.1.0.

fMeasureByThreshold
Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.

New in version 3.1.0.

falsePositiveRateByLabel
Returns false positive rate for each label (category).

New in version 3.1.0.

labelCol
Field in "predictions" which gives the true label of each instance.

New in version 3.1.0.

labels
Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.
New in version 3.1.0.
Notes
In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.
objectiveHistory
Objective function (scaled loss + regularization) at each iteration. It contains one more element, the initial state, than the number of iterations.
New in version 3.1.0.
pr
Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.

New in version 3.1.0.

precisionByLabel
Returns precision for each label (category).

New in version 3.1.0.

precisionByThreshold
Returns a dataframe with two fields (threshold, precision) curve. Every possible probability obtained in transforming the dataset is used as a threshold in calculating the precision.
New in version 3.1.0.
predictionCol
Field in "predictions" which gives the prediction of each class.

New in version 3.1.0.

predictions
DataFrame produced by the model's transform method.

New in version 3.1.0.

recallByLabel
Returns recall for each label (category).

New in version 3.1.0.

recallByThreshold
Returns a dataframe with two fields (threshold, recall) curve. Every possible probability obtained in transforming the dataset is used as a threshold in calculating the recall.

New in version 3.1.0.

roc
Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
New in version 3.1.0.
Notes
Wikipedia reference
scoreCol
Field in "predictions" which gives the probability or raw prediction of each class as a vector.

New in version 3.1.0.

totalIterations
Number of training iterations until termination.

New in version 3.1.0.

truePositiveRateByLabel
Returns true positive rate for each label (category).

New in version 3.1.0.

weightCol
Field in "predictions" which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRate
Returns weighted false positive rate.

New in version 3.1.0.

weightedPrecision
Returns weighted averaged precision.

New in version 3.1.0.

weightedRecall
Returns weighted averaged recall. (equal to precision, recall and f-measure)
New in version 3.1.0.
weightedTruePositiveRate
Returns weighted true positive rate. (equal to precision, recall and f-measure)
New in version 3.1.0.
3.3.5 Clustering
BisectingKMeans(*args, **kwargs) A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" by Steinbach, Karypis, and Kumar, with modification to fit Spark.
BisectingKMeansModel([java_model]) Model fitted by BisectingKMeans.
BisectingKMeansSummary([java_obj]) Bisecting KMeans clustering results for a given model.
KMeans(*args, **kwargs) K-means clustering with a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al).
KMeansModel([java_model]) Model fitted by KMeans.
KMeansSummary([java_obj]) Summary of KMeans.
GaussianMixture(*args, **kwargs) GaussianMixture clustering.
GaussianMixtureModel([java_model]) Model fitted by GaussianMixture.
GaussianMixtureSummary([java_obj]) Gaussian mixture clustering results for a given model.
LDA(*args, **kwargs) Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
LDAModel([java_model]) Latent Dirichlet Allocation (LDA) model.
LocalLDAModel([java_model]) Local (non-distributed) model fitted by LDA.
DistributedLDAModel([java_model]) Distributed model fitted by LDA.
PowerIterationClustering(*args, **kwargs) Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and Cohen. From the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise similarity matrix of the data.
BisectingKMeans
class pyspark.ml.clustering.BisectingKMeans(*args, **kwargs)
A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" by Steinbach, Karypis, and Kumar, with modification to fit Spark. The algorithm starts from a single cluster that contains all points. Iteratively it finds divisible clusters on the bottom level and bisects each of them using k-means, until there are k leaf clusters in total or no leaf clusters are divisible. The bisecting steps of clusters on the same level are grouped together to increase parallelism. If bisecting all divisible clusters on the bottom level would result in more than k leaf clusters, larger clusters get higher priority.
New in version 2.0.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> data = [(Vectors.dense([0.0, 0.0]), 2.0), (Vectors.dense([1.0, 1.0]), 2.0),
...         (Vectors.dense([9.0, 8.0]), 2.0), (Vectors.dense([8.0, 9.0]), 2.0)]
>>> df = spark.createDataFrame(data, ["features", "weighCol"])
>>> bkm = BisectingKMeans(k=2, minDivisibleClusterSize=1.0)
>>> bkm.setMaxIter(10)
BisectingKMeans...
>>> bkm.getMaxIter()
10
>>> bkm.clear(bkm.maxIter)
>>> bkm.setSeed(1)
BisectingKMeans...
>>> bkm.setWeightCol("weighCol")
BisectingKMeans...
>>> bkm.getSeed()
1
>>> bkm.clear(bkm.seed)
>>> model = bkm.fit(df)
>>> model.getMaxIter()
20
>>> model.setPredictionCol("newPrediction")
BisectingKMeansModel...
>>> model.predict(df.head().features)
0
>>> centers = model.clusterCenters()
>>> len(centers)
2
>>> model.computeCost(df)
2.0
>>> model.hasSummary
True
>>> summary = model.summary
>>> summary.k
2
>>> summary.clusterSizes
[2, 2]
>>> summary.trainingCost
4.000...
>>> transformed = model.transform(df).select("features", "newPrediction")
>>> rows = transformed.collect()
>>> rows[0].newPrediction == rows[1].newPrediction
True
>>> rows[2].newPrediction == rows[3].newPrediction
True
>>> bkm_path = temp_path + "/bkm"
>>> bkm.save(bkm_path)
>>> bkm2 = BisectingKMeans.load(bkm_path)
>>> bkm2.getK()
2
>>> bkm2.getDistanceMeasure()
'euclidean'
>>> model_path = temp_path + "/bkm_model"
>>> model.save(model_path)
>>> model2 = BisectingKMeansModel.load(model_path)
>>> model2.hasSummary
False
>>> model.clusterCenters()[0] == model2.clusterCenters()[0]
array([ True,  True], dtype=bool)
>>> model.clusterCenters()[1] == model2.clusterCenters()[1]
array([ True,  True], dtype=bool)
>>> model.transform(df).take(1) == model2.transform(df).take(1)
True
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getDistanceMeasure() Gets the value of distanceMeasure or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getK() Gets the value of k or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getMinDivisibleClusterSize() Gets the value of minDivisibleClusterSize or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getSeed() Gets the value of seed or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setDistanceMeasure(value) Sets the value of distanceMeasure.
setFeaturesCol(value) Sets the value of featuresCol.
setK(value) Sets the value of k.
setMaxIter(value) Sets the value of maxIter.
setMinDivisibleClusterSize(value) Sets the value of minDivisibleClusterSize.
setParams(self, \*[, featuresCol, ...]) Sets params for BisectingKMeans.
setPredictionCol(value) Sets the value of predictionCol.
setSeed(value) Sets the value of seed.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.
Attributes
distanceMeasure
featuresCol
k
maxIter
minDivisibleClusterSize
params Returns all params ordered by name.
predictionCol
seed
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getDistanceMeasure()
Gets the value of distanceMeasure or its default value.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getK()
Gets the value of k or its default value.

New in version 2.0.0.

getMaxIter()
Gets the value of maxIter or its default value.

getMinDivisibleClusterSize()
Gets the value of minDivisibleClusterSize or its default value.
New in version 2.0.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.
setDistanceMeasure(value)
Sets the value of distanceMeasure.

New in version 2.4.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 2.0.0.

setK(value)
Sets the value of k.

New in version 2.0.0.

setMaxIter(value)
Sets the value of maxIter.

New in version 2.0.0.

setMinDivisibleClusterSize(value)
Sets the value of minDivisibleClusterSize.
New in version 2.0.0.
setParams(self, \*, featuresCol="features", predictionCol="prediction", maxIter=20, seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean", weightCol=None)
Sets params for BisectingKMeans.
New in version 2.0.0.
setPredictionCol(value)
Sets the value of predictionCol.

New in version 2.0.0.

setSeed(value)
Sets the value of seed.

New in version 2.0.0.

setWeightCol(value)
Sets the value of weightCol.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
distanceMeasure = Param(parent='undefined', name='distanceMeasure', doc="the distance measure. Supported options: 'euclidean' and 'cosine'.")
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
k = Param(parent='undefined', name='k', doc='The desired number of leaf clusters. Must be > 1.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
minDivisibleClusterSize = Param(parent='undefined', name='minDivisibleClusterSize', doc='The minimum number of points (if >= 1.0) or the minimum proportion of points (if < 1.0) of a divisible cluster.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
BisectingKMeansModel
class pyspark.ml.clustering.BisectingKMeansModel(java_model=None)
Model fitted by BisectingKMeans.
New in version 2.0.0.
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
clusterCenters()  Get the cluster centers, represented as a list of NumPy arrays.
computeCost(dataset)  Computes the sum of squared distances between the input points and their corresponding cluster centers.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getDistanceMeasure()  Gets the value of distanceMeasure or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getK()  Gets the value of k or its default value.
getMaxIter()  Gets the value of maxIter or its default value.
getMinDivisibleClusterSize()  Gets the value of minDivisibleClusterSize or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getSeed()  Gets the value of seed or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)  Predict label for the given features.
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setPredictionCol(value)  Sets the value of predictionCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
distanceMeasure
featuresCol
hasSummary  Indicates whether a training summary exists for this model instance.
k
maxIter
minDivisibleClusterSize
params  Returns all params ordered by name.
predictionCol
seed
summary  Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set.
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

clusterCenters()
Get the cluster centers, represented as a list of NumPy arrays.

New in version 2.0.0.

computeCost(dataset)
Computes the sum of squared distances between the input points and their corresponding cluster centers.

Deprecated since version 3.0.0: It will be removed in future versions. Use ClusteringEvaluator instead. You can also get the cost on the training dataset in the summary.
New in version 2.0.0.
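Because computeCost is deprecated, here is a sketch of the two suggested replacements (illustrative, assuming a fitted model, its training DataFrame df, and default column names; note ClusteringEvaluator's default metric is the silhouette score, not the SSE):

>>> from pyspark.ml.evaluation import ClusteringEvaluator
>>> predictions = model.transform(df)
>>> silhouette = ClusteringEvaluator().evaluate(predictions)  # works on any dataset
>>> cost = model.summary.trainingCost  # SSE on the training data (Spark 3.0.0+)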
copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
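A sketch of the precedence ordering (illustrative; assumes a fitted model): values passed in extra win over user-set values, which win over defaults:

>>> pm = model.extractParamMap({model.predictionCol: "cluster"})
>>> pm[model.predictionCol]  # extra wins over the default 'prediction'
'cluster'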
getDistanceMeasure()
Gets the value of distanceMeasure or its default value.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getK()
Gets the value of k or its default value.

New in version 2.0.0.

getMaxIter()
Gets the value of maxIter or its default value.

getMinDivisibleClusterSize()
Gets the value of minDivisibleClusterSize or its default value.

New in version 2.0.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)
Predict label for the given features.

New in version 3.0.0.

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
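A sketch of overriding a param for a single call (illustrative; assumes a fitted model and df): passing a param map leaves the model itself unchanged and only affects this transform:

>>> out = model.transform(df, {model.predictionCol: "cluster"})
>>> "cluster" in out.columns
True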
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
distanceMeasure = Param(parent='undefined', name='distanceMeasure', doc="the distance measure. Supported options: 'euclidean' and 'cosine'.")
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
hasSummary
Indicates whether a training summary exists for this model instance.
New in version 2.1.0.
k = Param(parent='undefined', name='k', doc='The desired number of leaf clusters. Must be > 1.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
minDivisibleClusterSize = Param(parent='undefined', name='minDivisibleClusterSize', doc='The minimum number of points (if >= 1.0) or the minimum proportion of points (if < 1.0) of a divisible cluster.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
summary
Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set. An exception is thrown if no summary exists.
New in version 2.1.0.
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
BisectingKMeansSummary
class pyspark.ml.clustering.BisectingKMeansSummary(java_obj=None)
Bisecting KMeans clustering results for a given model.
New in version 2.1.0.
Attributes
cluster  DataFrame of predicted cluster centers for each training data point.
clusterSizes  Size of (number of data points in) each cluster.
featuresCol  Name for column of features in predictions.
k  The number of clusters the model was trained with.
numIter  Number of iterations.
predictionCol  Name for column of predicted clusters in predictions.
predictions  DataFrame produced by the model's transform method.
trainingCost  Sum of squared distances to the nearest centroid for all points in the training dataset.
Attributes Documentation
cluster
DataFrame of predicted cluster centers for each training data point.

New in version 2.1.0.

clusterSizes
Size of (number of data points in) each cluster.

New in version 2.1.0.

featuresCol
Name for column of features in predictions.

New in version 2.1.0.

k
The number of clusters the model was trained with.

New in version 2.1.0.

numIter
Number of iterations.

New in version 2.4.0.
predictionCol
Name for column of predicted clusters in predictions.

New in version 2.1.0.

predictions
DataFrame produced by the model's transform method.

New in version 2.1.0.

trainingCost
Sum of squared distances to the nearest centroid for all points in the training dataset. This is equivalent to sklearn's inertia.
New in version 3.0.0.
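A sketch of reading these fields together (illustrative; assumes a freshly fitted BisectingKMeansModel named model; a model reloaded from disk has no summary, so hasSummary should be checked first):

>>> if model.hasSummary:
...     s = model.summary
...     print(s.k, s.clusterSizes, s.trainingCost, s.numIter)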
KMeans
class pyspark.ml.clustering.KMeans(*args, **kwargs)
K-means clustering with a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al).
New in version 1.5.0.
Examples
>>> from pyspark.ml.linalg import Vectors>>> data = [(Vectors.dense([0.0, 0.0]), 2.0), (Vectors.dense([1.0, 1.0]), 2.0),... (Vectors.dense([9.0, 8.0]), 2.0), (Vectors.dense([8.0, 9.0]), 2.0)]>>> df = spark.createDataFrame(data, ["features", "weighCol"])>>> kmeans = KMeans(k=2)>>> kmeans.setSeed(1)KMeans...>>> kmeans.setWeightCol("weighCol")KMeans...>>> kmeans.setMaxIter(10)KMeans...>>> kmeans.getMaxIter()10>>> kmeans.clear(kmeans.maxIter)>>> model = kmeans.fit(df)>>> model.getDistanceMeasure()'euclidean'>>> model.setPredictionCol("newPrediction")KMeansModel...>>> model.predict(df.head().features)0>>> centers = model.clusterCenters()>>> len(centers)2>>> transformed = model.transform(df).select("features", "newPrediction")>>> rows = transformed.collect()>>> rows[0].newPrediction == rows[1].newPredictionTrue>>> rows[2].newPrediction == rows[3].newPredictionTrue>>> model.hasSummaryTrue
>>> summary = model.summary>>> summary.k2>>> summary.clusterSizes[2, 2]>>> summary.trainingCost4.0>>> kmeans_path = temp_path + "/kmeans">>> kmeans.save(kmeans_path)>>> kmeans2 = KMeans.load(kmeans_path)>>> kmeans2.getK()2>>> model_path = temp_path + "/kmeans_model">>> model.save(model_path)>>> model2 = KMeansModel.load(model_path)>>> model2.hasSummaryFalse>>> model.clusterCenters()[0] == model2.clusterCenters()[0]array([ True, True], dtype=bool)>>> model.clusterCenters()[1] == model2.clusterCenters()[1]array([ True, True], dtype=bool)>>> model.transform(df).take(1) == model2.transform(df).take(1)True
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getDistanceMeasure()  Gets the value of distanceMeasure or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getInitMode()  Gets the value of initMode.
getInitSteps()  Gets the value of initSteps.
getK()  Gets the value of k.
getMaxIter()  Gets the value of maxIter or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getSeed()  Gets the value of seed or its default value.
getTol()  Gets the value of tol or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setDistanceMeasure(value)  Sets the value of distanceMeasure.
setFeaturesCol(value)  Sets the value of featuresCol.
setInitMode(value)  Sets the value of initMode.
setInitSteps(value)  Sets the value of initSteps.
setK(value)  Sets the value of k.
setMaxIter(value)  Sets the value of maxIter.
setParams(self, \*[, featuresCol, ...])  Sets params for KMeans.
setPredictionCol(value)  Sets the value of predictionCol.
setSeed(value)  Sets the value of seed.
setTol(value)  Sets the value of tol.
setWeightCol(value)  Sets the value of weightCol.
write()  Returns an MLWriter instance for this ML instance.
Attributes
distanceMeasure
featuresCol
initMode
initSteps
k
maxIter
params  Returns all params ordered by name.
predictionCol
seed
tol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getDistanceMeasure()
Gets the value of distanceMeasure or its default value.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getInitMode()
Gets the value of initMode.

New in version 1.5.0.

getInitSteps()
Gets the value of initSteps.

New in version 1.5.0.

getK()
Gets the value of k.

New in version 1.5.0.

getMaxIter()
Gets the value of maxIter or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getTol()
Gets the value of tol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setDistanceMeasure(value)
Sets the value of distanceMeasure.

New in version 2.4.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 1.5.0.

setInitMode(value)
Sets the value of initMode.

New in version 1.5.0.

setInitSteps(value)
Sets the value of initSteps.

New in version 1.5.0.

setK(value)
Sets the value of k.

New in version 1.5.0.

setMaxIter(value)
Sets the value of maxIter.

New in version 1.5.0.

setParams(self, \*, featuresCol="features", predictionCol="prediction", k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, distanceMeasure="euclidean", weightCol=None)
Sets params for KMeans.

New in version 1.5.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 1.5.0.

setSeed(value)
Sets the value of seed.

New in version 1.5.0.

setTol(value)
Sets the value of tol.

New in version 1.5.0.

setWeightCol(value)
Sets the value of weightCol.

New in version 3.0.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
distanceMeasure = Param(parent='undefined', name='distanceMeasure', doc="the distance measure. Supported options: 'euclidean' and 'cosine'.")
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
initMode = Param(parent='undefined', name='initMode', doc='The initialization algorithm. This can be either "random" to choose random points as initial cluster centers, or "k-means||" to use a parallel variant of k-means++')
initSteps = Param(parent='undefined', name='initSteps', doc='The number of steps for k-means|| initialization mode. Must be > 0.')
k = Param(parent='undefined', name='k', doc='The number of clusters to create. Must be > 1.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
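As an illustrative sketch of the initialization params documented above (the variable name km is hypothetical): initSteps only matters in the default "k-means||" mode, since "random" picks initial centers directly:

>>> km = KMeans(k=3, initMode="k-means||", initSteps=5, tol=1e-6, seed=42)
>>> km.getInitSteps()
5
>>> km.setInitMode("random")  # initSteps is then ignored
KMeans...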
KMeansModel
class pyspark.ml.clustering.KMeansModel(java_model=None)
Model fitted by KMeans.
New in version 1.5.0.
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
clusterCenters()  Get the cluster centers, represented as a list of NumPy arrays.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getDistanceMeasure()  Gets the value of distanceMeasure or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getInitMode()  Gets the value of initMode.
getInitSteps()  Gets the value of initSteps.
getK()  Gets the value of k.
getMaxIter()  Gets the value of maxIter or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getSeed()  Gets the value of seed or its default value.
getTol()  Gets the value of tol or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)  Predict label for the given features.
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setPredictionCol(value)  Sets the value of predictionCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns a GeneralMLWriter instance for this ML instance.
Attributes
distanceMeasure
featuresCol
hasSummary  Indicates whether a training summary exists for this model instance.
initMode
initSteps
k
maxIter
params  Returns all params ordered by name.
predictionCol
seed
summary  Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set.
tol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

clusterCenters()
Get the cluster centers, represented as a list of NumPy arrays.

New in version 1.5.0.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getDistanceMeasure()
Gets the value of distanceMeasure or its default value.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getInitMode()
Gets the value of initMode.

New in version 1.5.0.

getInitSteps()
Gets the value of initSteps.

New in version 1.5.0.
getK()
Gets the value of k.

New in version 1.5.0.

getMaxIter()
Gets the value of maxIter or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getTol()
Gets the value of tol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
Predict label for the given features.
New in version 3.0.0.
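A sketch of scoring a single vector without building a DataFrame (illustrative; assumes a fitted model):

>>> from pyspark.ml.linalg import Vectors
>>> idx = model.predict(Vectors.dense([0.5, 0.5]))  # int index of the nearest cluster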
classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns a GeneralMLWriter instance for this ML instance.
Attributes Documentation
distanceMeasure = Param(parent='undefined', name='distanceMeasure', doc="the distance measure. Supported options: 'euclidean' and 'cosine'.")
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
hasSummary
Indicates whether a training summary exists for this model instance.
New in version 2.1.0.
initMode = Param(parent='undefined', name='initMode', doc='The initialization algorithm. This can be either "random" to choose random points as initial cluster centers, or "k-means||" to use a parallel variant of k-means++')
initSteps = Param(parent='undefined', name='initSteps', doc='The number of steps for k-means|| initialization mode. Must be > 0.')
k = Param(parent='undefined', name='k', doc='The number of clusters to create. Must be > 1.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
summary
Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set. An exception is thrown if no summary exists.
New in version 2.1.0.
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
KMeansSummary
class pyspark.ml.clustering.KMeansSummary(java_obj=None)
Summary of KMeans.
New in version 2.1.0.
Attributes
cluster  DataFrame of predicted cluster centers for each training data point.
clusterSizes  Size of (number of data points in) each cluster.
featuresCol  Name for column of features in predictions.
k  The number of clusters the model was trained with.
numIter  Number of iterations.
predictionCol  Name for column of predicted clusters in predictions.
predictions  DataFrame produced by the model's transform method.
trainingCost  K-means cost (sum of squared distances to the nearest centroid for all points in the training dataset).
Attributes Documentation
cluster
DataFrame of predicted cluster centers for each training data point.

New in version 2.1.0.

clusterSizes
Size of (number of data points in) each cluster.

New in version 2.1.0.

featuresCol
Name for column of features in predictions.

New in version 2.1.0.

k
The number of clusters the model was trained with.

New in version 2.1.0.

numIter
Number of iterations.

New in version 2.4.0.

predictionCol
Name for column of predicted clusters in predictions.

New in version 2.1.0.

predictions
DataFrame produced by the model's transform method.

New in version 2.1.0.
trainingCost
K-means cost (sum of squared distances to the nearest centroid for all points in the training dataset). This is equivalent to sklearn's inertia.
New in version 2.4.0.
GaussianMixture
class pyspark.ml.clustering.GaussianMixture(*args, **kwargs)
GaussianMixture clustering. This class performs expectation maximization for multivariate Gaussian Mixture Models (GMMs). A GMM represents a composite distribution of independent Gaussian distributions with associated "mixing" weights specifying each's contribution to the composite.

Given a set of sample points, this class will maximize the log-likelihood for a mixture of k Gaussians, iterating until the log-likelihood changes by less than convergenceTol, or until it has reached the max number of iterations. While this process is generally guaranteed to converge, it is not guaranteed to find a global optimum.
New in version 2.0.0.
Notes
For high-dimensional data (with many features), this algorithm may perform poorly. This is due to high-dimensional data (a) making it difficult to cluster at all (based on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> data = [(Vectors.dense([-0.1, -0.05 ]),),... (Vectors.dense([-0.01, -0.1]),),... (Vectors.dense([0.9, 0.8]),),... (Vectors.dense([0.75, 0.935]),),... (Vectors.dense([-0.83, -0.68]),),... (Vectors.dense([-0.91, -0.76]),)]>>> df = spark.createDataFrame(data, ["features"])>>> gm = GaussianMixture(k=3, tol=0.0001, seed=10)>>> gm.getMaxIter()100>>> gm.setMaxIter(30)GaussianMixture...>>> gm.getMaxIter()30>>> model = gm.fit(df)>>> model.getAggregationDepth()2>>> model.getFeaturesCol()'features'>>> model.setPredictionCol("newPrediction")GaussianMixtureModel...>>> model.predict(df.head().features)2>>> model.predictProbability(df.head().features)DenseVector([0.0, 0.0, 1.0])
>>> model.hasSummaryTrue>>> summary = model.summary>>> summary.k3>>> summary.clusterSizes[2, 2, 2]>>> summary.logLikelihood65.02945...>>> weights = model.weights>>> len(weights)3>>> gaussians = model.gaussians>>> len(gaussians)3>>> gaussians[0].meanDenseVector([0.825, 0.8675])>>> gaussians[0].covDenseMatrix(2, 2, [0.0056, -0.0051, -0.0051, 0.0046], 0)>>> gaussians[1].meanDenseVector([-0.87, -0.72])>>> gaussians[1].covDenseMatrix(2, 2, [0.0016, 0.0016, 0.0016, 0.0016], 0)>>> gaussians[2].meanDenseVector([-0.055, -0.075])>>> gaussians[2].covDenseMatrix(2, 2, [0.002, -0.0011, -0.0011, 0.0006], 0)>>> model.gaussiansDF.select("mean").head()Row(mean=DenseVector([0.825, 0.8675]))>>> model.gaussiansDF.select("cov").head()Row(cov=DenseMatrix(2, 2, [0.0056, -0.0051, -0.0051, 0.0046], False))>>> transformed = model.transform(df).select("features", "newPrediction")>>> rows = transformed.collect()>>> rows[4].newPrediction == rows[5].newPredictionTrue>>> rows[2].newPrediction == rows[3].newPredictionTrue>>> gmm_path = temp_path + "/gmm">>> gm.save(gmm_path)>>> gm2 = GaussianMixture.load(gmm_path)>>> gm2.getK()3>>> model_path = temp_path + "/gmm_model">>> model.save(model_path)>>> model2 = GaussianMixtureModel.load(model_path)>>> model2.hasSummaryFalse>>> model2.weights == model.weightsTrue>>> model2.gaussians[0].mean == model.gaussians[0].meanTrue>>> model2.gaussians[0].cov == model.gaussians[0].covTrue>>> model2.gaussians[1].mean == model.gaussians[1].meanTrue>>> model2.gaussians[1].cov == model.gaussians[1].covTrue
>>> model2.gaussians[2].mean == model.gaussians[2].meanTrue>>> model2.gaussians[2].cov == model.gaussians[2].covTrue>>> model2.gaussiansDF.select("mean").head()Row(mean=DenseVector([0.825, 0.8675]))>>> model2.gaussiansDF.select("cov").head()Row(cov=DenseMatrix(2, 2, [0.0056, -0.0051, -0.0051, 0.0046], False))>>> model.transform(df).take(1) == model2.transform(df).take(1)True>>> gm2.setWeightCol("weight")GaussianMixture...
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getAggregationDepth()  Gets the value of aggregationDepth or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getK()  Gets the value of k.
getMaxIter()  Gets the value of maxIter or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getSeed()  Gets the value of seed or its default value.
getTol()  Gets the value of tol or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setAggregationDepth(value)  Sets the value of aggregationDepth.
setFeaturesCol(value)  Sets the value of featuresCol.
setK(value)  Sets the value of k.
setMaxIter(value)  Sets the value of maxIter.
setParams(self, \*[, featuresCol, ...])  Sets params for GaussianMixture.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setSeed(value)  Sets the value of seed.
setTol(value)  Sets the value of tol.
setWeightCol(value)  Sets the value of weightCol.
write()  Returns an MLWriter instance for this ML instance.
Attributes
aggregationDepth
featuresCol
k
maxIter
params  Returns all params ordered by name.
predictionCol
probabilityCol
seed
tol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getAggregationDepth()
Gets the value of aggregationDepth or its default value.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getK()
Gets the value of k.
New in version 2.0.0.
getMaxIter()
Gets the value of maxIter or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getSeed()
Gets the value of seed or its default value.

getTol()
Gets the value of tol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.
setAggregationDepth(value)
Sets the value of aggregationDepth.

New in version 3.0.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 2.0.0.

setK(value)
Sets the value of k.
New in version 2.0.0.
setMaxIter(value)
Sets the value of maxIter.

New in version 2.0.0.

setParams(self, \*, featuresCol="features", predictionCol="prediction", k=2, probabilityCol="probability", tol=0.01, maxIter=100, seed=None, aggregationDepth=2, weightCol=None)
Sets params for GaussianMixture.

New in version 2.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 2.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 2.0.0.

setSeed(value)
Sets the value of seed.

New in version 2.0.0.

setTol(value)
Sets the value of tol.

New in version 2.0.0.

setWeightCol(value)
Sets the value of weightCol.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
k = Param(parent='undefined', name='k', doc='Number of independent Gaussians in the mixture model. Must be > 1.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
GaussianMixtureModel
class pyspark.ml.clustering.GaussianMixtureModel(java_model=None)
Model fitted by GaussianMixture.
New in version 2.0.0.
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getAggregationDepth()  Gets the value of aggregationDepth or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getK()  Gets the value of k.
getMaxIter()  Gets the value of maxIter or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getSeed()  Gets the value of seed or its default value.
getTol()  Gets the value of tol or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)  Predict label for the given features.
predictProbability(value)  Predict probability for the given features.
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
aggregationDepth
featuresCol
gaussians  Array of MultivariateGaussian where gaussians[i] represents the Multivariate Gaussian (Normal) Distribution for Gaussian i.
gaussiansDF  Retrieve Gaussian distributions as a DataFrame.
hasSummary  Indicates whether a training summary exists for this model instance.
k
maxIter
params  Returns all params ordered by name.
predictionCol
probabilityCol
seed
summary  Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set.
tol
weightCol
weights  Weight for each Gaussian distribution in the mixture.
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getAggregationDepth()
Gets the value of aggregationDepth or its default value.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getK()
Gets the value of k.
New in version 2.0.0.
getMaxIter()
Gets the value of maxIter or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getSeed()
Gets the value of seed or its default value.

getTol()
Gets the value of tol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)
Predict label for the given features.

New in version 3.0.0.

predictProbability(value)
Predict probability for the given features.
New in version 3.0.0.
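A sketch contrasting the hard and soft assignments (illustrative; assumes a fitted GaussianMixtureModel named model):

>>> from pyspark.ml.linalg import Vectors
>>> x = Vectors.dense([-0.05, -0.08])
>>> hard = model.predict(x)             # int index of the most likely Gaussian
>>> soft = model.predictProbability(x)  # DenseVector of k membership probabilities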
classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
gaussians
Array of MultivariateGaussian where gaussians[i] represents the Multivariate Gaussian (Normal) Distribution for Gaussian i.
New in version 3.0.0.
gaussiansDF
Retrieve Gaussian distributions as a DataFrame. Each row represents a Gaussian Distribution. The DataFrame has two columns: mean (Vector) and cov (Matrix).
New in version 2.0.0.
hasSummary
Indicates whether a training summary exists for this model instance.
New in version 2.1.0.
k = Param(parent='undefined', name='k', doc='Number of independent Gaussians in the mixture model. Must be > 1.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
summary
Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set. An exception is thrown if no summary exists.
New in version 2.1.0.
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
weights
Weight for each Gaussian distribution in the mixture. This is a multinomial probability distribution over the k Gaussians, where weights[i] is the weight for Gaussian i, and weights sum to 1.
New in version 2.0.0.
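A sketch of inspecting the fitted mixture through these attributes (illustrative; assumes a fitted model with k components):

>>> w = model.weights              # list of k floats summing to 1.0
>>> mu0 = model.gaussians[0].mean  # Vector: mean of component 0
>>> cov0 = model.gaussians[0].cov  # Matrix: covariance of component 0
>>> gdf = model.gaussiansDF        # the same information as a DataFrame (mean, cov)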
GaussianMixtureSummary
class pyspark.ml.clustering.GaussianMixtureSummary(java_obj=None)
Gaussian mixture clustering results for a given model.
New in version 2.1.0.
Attributes
cluster  DataFrame of predicted cluster centers for each training data point.
clusterSizes  Size of (number of data points in) each cluster.
featuresCol  Name for column of features in predictions.
k  The number of clusters the model was trained with.
logLikelihood  Total log-likelihood for this model on the given data.
numIter  Number of iterations.
predictionCol  Name for column of predicted clusters in predictions.
predictions  DataFrame produced by the model's transform method.
probability  DataFrame of probabilities of each cluster for each training data point.
probabilityCol  Name for column of predicted probability of each cluster in predictions.
Attributes Documentation
cluster
DataFrame of predicted cluster centers for each training data point.

New in version 2.1.0.

clusterSizes
Size of (number of data points in) each cluster.

New in version 2.1.0.

featuresCol
Name for column of features in predictions.

New in version 2.1.0.

k
The number of clusters the model was trained with.

New in version 2.1.0.

logLikelihood
Total log-likelihood for this model on the given data.

New in version 2.2.0.

numIter
Number of iterations.

New in version 2.4.0.

predictionCol
Name for column of predicted clusters in predictions.

New in version 2.1.0.

predictions
DataFrame produced by the model's transform method.

New in version 2.1.0.

probability
DataFrame of probabilities of each cluster for each training data point.

New in version 2.1.0.

probabilityCol
Name for column of predicted probability of each cluster in predictions.

New in version 2.1.0.
764 Chapter 3. API Reference
pyspark Documentation, Release master
LDA
class pyspark.ml.clustering.LDA(*args, **kwargs)
Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
Terminology:
• “term” = “word”: an element of the vocabulary
• “token”: instance of a term appearing in a document
• “topic”: multinomial distribution over terms representing some concept
• “document”: one piece of text, corresponding to one row in the input data
Original LDA paper (journal version): Blei, Ng, and Jordan. “Latent Dirichlet Allocation.” JMLR, 2003.
Input data (featuresCol): LDA is given a collection of documents as input data, via the featuresCol parameter. Each document is specified as a Vector of length vocabSize, where each entry is the count for the corresponding term (word) in the document. Feature transformers such as pyspark.ml.feature.Tokenizer and pyspark.ml.feature.CountVectorizer can be useful for converting text to word count vectors.
New in version 2.0.0.
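A sketch of the text-to-counts preparation described above (illustrative; assumes an active SparkSession named spark and toy documents):

>>> from pyspark.ml.feature import Tokenizer, CountVectorizer
>>> from pyspark.ml.clustering import LDA
>>> docs = spark.createDataFrame([(0, "spark is fast"), (1, "lda finds topics")],
...                              ["id", "text"])
>>> words = Tokenizer(inputCol="text", outputCol="words").transform(docs)
>>> cv = CountVectorizer(inputCol="words", outputCol="features").fit(words)
>>> counts = cv.transform(words)  # each row holds a vocabSize-length count Vector
>>> ldaModel = LDA(k=2, maxIter=10).fit(counts)
>>> topics = ldaModel.describeTopics()  # termIndices map into cv.vocabulary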
Examples
>>> from pyspark.ml.linalg import Vectors, SparseVector>>> from pyspark.ml.clustering import LDA>>> df = spark.createDataFrame([[1, Vectors.dense([0.0, 1.0])],... [2, SparseVector(2, {0: 1.0})],], ["id", "features"])>>> lda = LDA(k=2, seed=1, optimizer="em")>>> lda.setMaxIter(10)LDA...>>> lda.getMaxIter()10>>> lda.clear(lda.maxIter)>>> model = lda.fit(df)>>> model.setSeed(1)DistributedLDAModel...>>> model.getTopicDistributionCol()'topicDistribution'>>> model.isDistributed()True>>> localModel = model.toLocal()>>> localModel.isDistributed()False>>> model.vocabSize()2>>> model.describeTopics().show()+-----+-----------+--------------------+|topic|termIndices| termWeights|+-----+-----------+--------------------+| 0| [1, 0]|[0.50401530077160...|| 1| [0, 1]|[0.50401530077160...|+-----+-----------+--------------------+...>>> model.topicsMatrix()DenseMatrix(2, 2, [0.496, 0.504, 0.504, 0.496], 0)
>>> lda_path = temp_path + "/lda">>> lda.save(lda_path)>>> sameLDA = LDA.load(lda_path)>>> distributed_model_path = temp_path + "/lda_distributed_model">>> model.save(distributed_model_path)>>> sameModel = DistributedLDAModel.load(distributed_model_path)>>> local_model_path = temp_path + "/lda_local_model">>> localModel.save(local_model_path)>>> sameLocalModel = LocalLDAModel.load(local_model_path)>>> model.transform(df).take(1) == sameLocalModel.transform(df).take(1)True
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getCheckpointInterval()  Gets the value of checkpointInterval or its default value.
getDocConcentration()  Gets the value of docConcentration or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getK()  Gets the value of k or its default value.
getKeepLastCheckpoint()  Gets the value of keepLastCheckpoint or its default value.
getLearningDecay()  Gets the value of learningDecay or its default value.
getLearningOffset()  Gets the value of learningOffset or its default value.
getMaxIter()  Gets the value of maxIter or its default value.
getOptimizeDocConcentration()  Gets the value of optimizeDocConcentration or its default value.
getOptimizer()  Gets the value of optimizer or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getSeed()  Gets the value of seed or its default value.
getSubsamplingRate()  Gets the value of subsamplingRate or its default value.
getTopicConcentration()  Gets the value of topicConcentration or its default value.
getTopicDistributionCol()  Gets the value of topicDistributionCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value)  Sets a parameter in the embedded param map.
setCheckpointInterval(value)  Sets the value of checkpointInterval.
setDocConcentration(value)  Sets the value of docConcentration.
setFeaturesCol(value)  Sets the value of featuresCol.
setK(value)  Sets the value of k.
setKeepLastCheckpoint(value)  Sets the value of keepLastCheckpoint.
setLearningDecay(value)  Sets the value of learningDecay.
setLearningOffset(value)  Sets the value of learningOffset.
setMaxIter(value)  Sets the value of maxIter.
setOptimizeDocConcentration(value)  Sets the value of optimizeDocConcentration.
setOptimizer(value)  Sets the value of optimizer.
setParams(self, \*[, featuresCol, maxIter, ...])  Sets params for LDA.
setSeed(value)  Sets the value of seed.
setSubsamplingRate(value)  Sets the value of subsamplingRate.
setTopicConcentration(value)  Sets the value of topicConcentration.
setTopicDistributionCol(value)  Sets the value of topicDistributionCol.
write()  Returns an MLWriter instance for this ML instance.
Attributes
checkpointInterval
docConcentration
featuresCol
k
keepLastCheckpoint
learningDecay
learningOffset
maxIter
optimizeDocConcentration
optimizer
params Returns all params ordered by name.
seed
subsamplingRate
topicConcentration
topicDistributionCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.
copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
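A small sketch of that precedence, using LDA's maxIter param (the values are arbitrary):

>>> from pyspark.ml.clustering import LDA
>>> lda = LDA(maxIter=5)  # user-supplied value overrides the default of 20
>>> param_map = lda.extractParamMap({lda.maxIter: 1})  # extra overrides user-supplied
>>> param_map[lda.maxIter]
1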
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
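A sketch of consuming that iterator, assuming a DataFrame df with a "features" column as in the class example; the param maps are illustrative:

>>> lda = LDA(seed=1, maxIter=2)
>>> param_maps = [{lda.k: 2}, {lda.k: 3}]
>>> models = [None] * len(param_maps)
>>> for index, model in lda.fitMultiple(df, param_maps):
...     models[index] = model  # indices may arrive out of order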
getCheckpointInterval()
Gets the value of checkpointInterval or its default value.
getDocConcentration()
Gets the value of docConcentration or its default value.
New in version 2.0.0.
getFeaturesCol()
Gets the value of featuresCol or its default value.
getK()
Gets the value of k or its default value.
New in version 2.0.0.
getKeepLastCheckpoint()
Gets the value of keepLastCheckpoint or its default value.
New in version 2.0.0.
getLearningDecay()
Gets the value of learningDecay or its default value.
New in version 2.0.0.
getLearningOffset()
Gets the value of learningOffset or its default value.
New in version 2.0.0.
getMaxIter()
Gets the value of maxIter or its default value.
getOptimizeDocConcentration()
Gets the value of optimizeDocConcentration or its default value.
New in version 2.0.0.
getOptimizer()
Gets the value of optimizer or its default value.
New in version 2.0.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getParam(paramName)
Gets a param by its name.
getSeed()
Gets the value of seed or its default value.
getSubsamplingRate()
Gets the value of subsamplingRate or its default value.
New in version 2.0.0.
getTopicConcentration()
Gets the value of topicConcentration or its default value.
New in version 2.0.0.
getTopicDistributionCol()
Gets the value of topicDistributionCol or its default value.
New in version 2.0.0.
hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.
isDefined(param)
Checks whether a param is explicitly set by user or has a default value.
isSet(param)
Checks whether a param is explicitly set by user.
classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()
Returns an MLReader instance for this class.
save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)
Sets a parameter in the embedded param map.
setCheckpointInterval(value)
Sets the value of checkpointInterval.
New in version 2.0.0.
setDocConcentration(value)
Sets the value of docConcentration.
Examples
>>> algo = LDA().setDocConcentration([0.1, 0.2])
>>> algo.getDocConcentration()
[0.1..., 0.2...]
New in version 2.0.0.
setFeaturesCol(value)
Sets the value of featuresCol.
New in version 2.0.0.
setK(value)
Sets the value of k.

Examples

>>> algo = LDA().setK(10)
>>> algo.getK()
10
New in version 2.0.0.
setKeepLastCheckpoint(value)
Sets the value of keepLastCheckpoint.
Examples
>>> algo = LDA().setKeepLastCheckpoint(False)
>>> algo.getKeepLastCheckpoint()
False
New in version 2.0.0.
setLearningDecay(value)
Sets the value of learningDecay.
Examples
>>> algo = LDA().setLearningDecay(0.1)
>>> algo.getLearningDecay()
0.1...
New in version 2.0.0.
setLearningOffset(value)
Sets the value of learningOffset.
Examples
>>> algo = LDA().setLearningOffset(100)
>>> algo.getLearningOffset()
100.0
New in version 2.0.0.
setMaxIter(value)
Sets the value of maxIter.
New in version 2.0.0.
setOptimizeDocConcentration(value)
Sets the value of optimizeDocConcentration.
Examples
>>> algo = LDA().setOptimizeDocConcentration(True)
>>> algo.getOptimizeDocConcentration()
True
New in version 2.0.0.
setOptimizer(value)
Sets the value of optimizer. Currently only supports 'em' and 'online'.
Examples
>>> algo = LDA().setOptimizer("em")
>>> algo.getOptimizer()
'em'
New in version 2.0.0.
setParams(self, \*, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10, k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51, subsamplingRate=0.05, optimizeDocConcentration=True, docConcentration=None, topicConcentration=None, topicDistributionCol="topicDistribution", keepLastCheckpoint=True)
Sets params for LDA.
New in version 2.0.0.
setSeed(value)
Sets the value of seed.
New in version 2.0.0.
setSubsamplingRate(value)
Sets the value of subsamplingRate.
Examples
>>> algo = LDA().setSubsamplingRate(0.1)
>>> algo.getSubsamplingRate()
0.1...
New in version 2.0.0.
setTopicConcentration(value)
Sets the value of topicConcentration.
Examples
>>> algo = LDA().setTopicConcentration(0.5)
>>> algo.getTopicConcentration()
0.5...
New in version 2.0.0.
setTopicDistributionCol(value)
Sets the value of topicDistributionCol.
Examples
>>> algo = LDA().setTopicDistributionCol("topicDistributionCol")
>>> algo.getTopicDistributionCol()
'topicDistributionCol'
New in version 2.0.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
docConcentration = Param(parent='undefined', name='docConcentration', doc='Concentration parameter (commonly named "alpha") for the prior placed on documents\' distributions over topics ("theta").')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
k = Param(parent='undefined', name='k', doc='The number of topics (clusters) to infer. Must be > 1.')
keepLastCheckpoint = Param(parent='undefined', name='keepLastCheckpoint', doc='(For EM optimizer) If using checkpointing, this indicates whether to keep the last checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with care.')
learningDecay = Param(parent='undefined', name='learningDecay', doc='Learning rate, set as an exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic convergence.')
learningOffset = Param(parent='undefined', name='learningOffset', doc='A (positive) learning parameter that downweights early iterations. Larger values make early iterations count less')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
optimizeDocConcentration = Param(parent='undefined', name='optimizeDocConcentration', doc='Indicates whether the docConcentration (Dirichlet parameter for document-topic distribution) will be optimized during training.')
optimizer = Param(parent='undefined', name='optimizer', doc='Optimizer or inference algorithm used to estimate the LDA model. Supported: online, em')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
seed = Param(parent='undefined', name='seed', doc='random seed.')
subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].')
topicConcentration = Param(parent='undefined', name='topicConcentration', doc='Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics\' distributions over terms.')
topicDistributionCol = Param(parent='undefined', name='topicDistributionCol', doc='Output column with estimates of the topic mixture distribution for each document (often called "theta" in the literature). Returns a vector of zeros for an empty document.')
LDAModel
class pyspark.ml.clustering.LDAModel(java_model=None)
Latent Dirichlet Allocation (LDA) model. This abstraction permits different underlying representations, including local and distributed data structures.
New in version 2.0.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
describeTopics([maxTermsPerTopic]) Return the topics described by their top-weighted terms.
estimatedDocConcentration() Value for LDA.docConcentration estimated from data.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCheckpointInterval() Gets the value of checkpointInterval or its default value.
getDocConcentration() Gets the value of docConcentration or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getK() Gets the value of k or its default value.
getKeepLastCheckpoint() Gets the value of keepLastCheckpoint or its default value.
getLearningDecay() Gets the value of learningDecay or its default value.
getLearningOffset() Gets the value of learningOffset or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOptimizeDocConcentration() Gets the value of optimizeDocConcentration or its default value.
getOptimizer() Gets the value of optimizer or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getSeed() Gets the value of seed or its default value.
getSubsamplingRate() Gets the value of subsamplingRate or its default value.
getTopicConcentration() Gets the value of topicConcentration or its default value.
getTopicDistributionCol() Gets the value of topicDistributionCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isDistributed() Indicates whether this instance is of type DistributedLDAModel.
isSet(param) Checks whether a param is explicitly set by user.
logLikelihood(dataset) Calculates a lower bound on the log likelihood of the entire corpus.
logPerplexity(dataset) Calculate an upper bound on perplexity.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setSeed(value) Sets the value of seed.
setTopicDistributionCol(value) Sets the value of topicDistributionCol.
topicsMatrix() Inferred topics, where each topic is represented by a distribution over terms.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
vocabSize() Vocabulary size (number of terms or words in the vocabulary)
Attributes
checkpointInterval
docConcentration
featuresCol
k
keepLastCheckpoint
learningDecay
learningOffset
maxIter
optimizeDocConcentration
optimizer
params Returns all params ordered by name.
seed
subsamplingRate
topicConcentration
topicDistributionCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.
copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
describeTopics(maxTermsPerTopic=10)
Return the topics described by their top-weighted terms.
New in version 2.0.0.
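For instance, with a fitted model as in the LDA class example, maxTermsPerTopic narrows the listing (a sketch; the returned columns are fixed):

>>> model.describeTopics(maxTermsPerTopic=1).columns
['topic', 'termIndices', 'termWeights']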
estimatedDocConcentration()
Value for LDA.docConcentration estimated from data. If Online LDA was used and LDA.optimizeDocConcentration was set to false, then this returns the fixed (given) value for the LDA.docConcentration parameter.
New in version 2.0.0.
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getCheckpointInterval()
Gets the value of checkpointInterval or its default value.
getDocConcentration()
Gets the value of docConcentration or its default value.
New in version 2.0.0.
getFeaturesCol()
Gets the value of featuresCol or its default value.
getK()
Gets the value of k or its default value.
New in version 2.0.0.
getKeepLastCheckpoint()
Gets the value of keepLastCheckpoint or its default value.
New in version 2.0.0.
getLearningDecay()
Gets the value of learningDecay or its default value.
New in version 2.0.0.
getLearningOffset()
Gets the value of learningOffset or its default value.
New in version 2.0.0.
getMaxIter()
Gets the value of maxIter or its default value.
getOptimizeDocConcentration()
Gets the value of optimizeDocConcentration or its default value.
New in version 2.0.0.
getOptimizer()
Gets the value of optimizer or its default value.
New in version 2.0.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getParam(paramName)
Gets a param by its name.
getSeed()
Gets the value of seed or its default value.
getSubsamplingRate()
Gets the value of subsamplingRate or its default value.
New in version 2.0.0.
getTopicConcentration()
Gets the value of topicConcentration or its default value.
New in version 2.0.0.
getTopicDistributionCol()
Gets the value of topicDistributionCol or its default value.
New in version 2.0.0.
hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.
isDefined(param)
Checks whether a param is explicitly set by user or has a default value.
isDistributed()
Indicates whether this instance is of type DistributedLDAModel.
New in version 2.0.0.
isSet(param)
Checks whether a param is explicitly set by user.
logLikelihood(dataset)
Calculates a lower bound on the log likelihood of the entire corpus. See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
Warning: If this model is an instance of DistributedLDAModel (produced when optimizeris set to “em”), this involves collecting a large topicsMatrix() to the driver. This implementationmay be changed in the future.
New in version 2.0.0.
logPerplexity(dataset)
Calculate an upper bound on perplexity. (Lower is better.) See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
Warning: If this model is an instance of DistributedLDAModel (produced when optimizeris set to “em”), this involves collecting a large topicsMatrix() to the driver. This implementationmay be changed in the future.
New in version 2.0.0.
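A usage sketch, assuming a fitted model and the training DataFrame df from the LDA class example; the two bounds move in opposite directions (higher log likelihood is better, lower perplexity is better):

>>> ll = model.logLikelihood(df)  # lower bound on the corpus log likelihood
>>> lp = model.logPerplexity(df)  # upper bound on perplexity; lower is better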
set(param, value)
Sets a parameter in the embedded param map.
setFeaturesCol(value)
Sets the value of featuresCol.
New in version 3.0.0.
setSeed(value)
Sets the value of seed.
New in version 3.0.0.
setTopicDistributionCol(value)
Sets the value of topicDistributionCol.
New in version 3.0.0.
topicsMatrix()
Inferred topics, where each topic is represented by a distribution over terms. This is a matrix of size vocabSize x k, where each column is a topic. No guarantees are given about the ordering of the topics.
Warning: If this model is actually a DistributedLDAModel instance produced by theExpectation-Maximization (“em”) optimizer, then this method could involve collecting a large amountof data to the driver (on the order of vocabSize x k).
New in version 2.0.0.
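For example, the matrix can be inspected column by column via NumPy (a sketch, reusing the fitted model from the class example):

>>> arr = model.topicsMatrix().toArray()  # numpy array of shape (vocabSize, k)
>>> topic0_term_weights = arr[:, 0]  # column j holds the term weights of topic j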
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
vocabSize()
Vocabulary size (number of terms or words in the vocabulary).
New in version 2.0.0.
Attributes Documentation
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
docConcentration = Param(parent='undefined', name='docConcentration', doc='Concentration parameter (commonly named "alpha") for the prior placed on documents\' distributions over topics ("theta").')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
k = Param(parent='undefined', name='k', doc='The number of topics (clusters) to infer. Must be > 1.')
keepLastCheckpoint = Param(parent='undefined', name='keepLastCheckpoint', doc='(For EM optimizer) If using checkpointing, this indicates whether to keep the last checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with care.')
learningDecay = Param(parent='undefined', name='learningDecay', doc='Learning rate, set as an exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic convergence.')
learningOffset = Param(parent='undefined', name='learningOffset', doc='A (positive) learning parameter that downweights early iterations. Larger values make early iterations count less')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
optimizeDocConcentration = Param(parent='undefined', name='optimizeDocConcentration', doc='Indicates whether the docConcentration (Dirichlet parameter for document-topic distribution) will be optimized during training.')
optimizer = Param(parent='undefined', name='optimizer', doc='Optimizer or inference algorithm used to estimate the LDA model. Supported: online, em')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
seed = Param(parent='undefined', name='seed', doc='random seed.')
subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].')
topicConcentration = Param(parent='undefined', name='topicConcentration', doc='Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics\' distributions over terms.')
topicDistributionCol = Param(parent='undefined', name='topicDistributionCol', doc='Output column with estimates of the topic mixture distribution for each document (often called "theta" in the literature). Returns a vector of zeros for an empty document.')
LocalLDAModel
class pyspark.ml.clustering.LocalLDAModel(java_model=None)
Local (non-distributed) model fitted by LDA. This model stores the inferred topics only; it does not store info about the training dataset.
New in version 2.0.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
describeTopics([maxTermsPerTopic]) Return the topics described by their top-weighted terms.
estimatedDocConcentration() Value for LDA.docConcentration estimated from data.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCheckpointInterval() Gets the value of checkpointInterval or its default value.
getDocConcentration() Gets the value of docConcentration or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getK() Gets the value of k or its default value.
getKeepLastCheckpoint() Gets the value of keepLastCheckpoint or its default value.
getLearningDecay() Gets the value of learningDecay or its default value.
getLearningOffset() Gets the value of learningOffset or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOptimizeDocConcentration() Gets the value of optimizeDocConcentration or its default value.
getOptimizer() Gets the value of optimizer or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getSeed() Gets the value of seed or its default value.
getSubsamplingRate() Gets the value of subsamplingRate or its default value.
getTopicConcentration() Gets the value of topicConcentration or its default value.
getTopicDistributionCol() Gets the value of topicDistributionCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isDistributed() Indicates whether this instance is of type DistributedLDAModel.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
logLikelihood(dataset) Calculates a lower bound on the log likelihood of the entire corpus.
logPerplexity(dataset) Calculate an upper bound on perplexity.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setSeed(value) Sets the value of seed.
setTopicDistributionCol(value) Sets the value of topicDistributionCol.
topicsMatrix() Inferred topics, where each topic is represented by a distribution over terms.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
vocabSize() Vocabulary size (number of terms or words in the vocabulary)
write() Returns an MLWriter instance for this ML instance.
Attributes
checkpointInterval
docConcentration
featuresCol
k
keepLastCheckpoint
learningDecay
learningOffset
maxIter
optimizeDocConcentration
optimizer
params Returns all params ordered by name.
seed
subsamplingRate
topicConcentration
topicDistributionCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.
copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
describeTopics(maxTermsPerTopic=10)
Return the topics described by their top-weighted terms.
New in version 2.0.0.
estimatedDocConcentration()
Value for LDA.docConcentration estimated from data. If Online LDA was used and LDA.optimizeDocConcentration was set to false, then this returns the fixed (given) value for the LDA.docConcentration parameter.
New in version 2.0.0.
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getCheckpointInterval()
Gets the value of checkpointInterval or its default value.
getDocConcentration()
Gets the value of docConcentration or its default value.
New in version 2.0.0.
getFeaturesCol()
Gets the value of featuresCol or its default value.
getK()
Gets the value of k or its default value.
New in version 2.0.0.
getKeepLastCheckpoint()
Gets the value of keepLastCheckpoint or its default value.
New in version 2.0.0.
getLearningDecay()
Gets the value of learningDecay or its default value.
New in version 2.0.0.
getLearningOffset()
Gets the value of learningOffset or its default value.
New in version 2.0.0.
getMaxIter()
Gets the value of maxIter or its default value.
getOptimizeDocConcentration()
Gets the value of optimizeDocConcentration or its default value.
New in version 2.0.0.
getOptimizer()
Gets the value of optimizer or its default value.
New in version 2.0.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getParam(paramName)
Gets a param by its name.
getSeed()
Gets the value of seed or its default value.
getSubsamplingRate()
Gets the value of subsamplingRate or its default value.
New in version 2.0.0.
getTopicConcentration()
Gets the value of topicConcentration or its default value.
New in version 2.0.0.
getTopicDistributionCol()
Gets the value of topicDistributionCol or its default value.
New in version 2.0.0.
hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.
isDefined(param)
Checks whether a param is explicitly set by user or has a default value.
isDistributed()
Indicates whether this instance is of type DistributedLDAModel.
New in version 2.0.0.
isSet(param)
Checks whether a param is explicitly set by user.
classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
logLikelihood(dataset)
Calculates a lower bound on the log likelihood of the entire corpus. See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
Warning: If this model is an instance of DistributedLDAModel (produced when optimizeris set to “em”), this involves collecting a large topicsMatrix() to the driver. This implementationmay be changed in the future.
New in version 2.0.0.
logPerplexity(dataset)
Calculate an upper bound on perplexity. (Lower is better.) See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
Warning: If this model is an instance of DistributedLDAModel (produced when optimizeris set to “em”), this involves collecting a large topicsMatrix() to the driver. This implementationmay be changed in the future.
New in version 2.0.0.
classmethod read()
Returns an MLReader instance for this class.
save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)
Sets a parameter in the embedded param map.
setFeaturesCol(value)
Sets the value of featuresCol.
New in version 3.0.0.
setSeed(value)
Sets the value of seed.
New in version 3.0.0.
setTopicDistributionCol(value)
Sets the value of topicDistributionCol.
New in version 3.0.0.
topicsMatrix()
Inferred topics, where each topic is represented by a distribution over terms. This is a matrix of size vocabSize x k, where each column is a topic. No guarantees are given about the ordering of the topics.
Warning: If this model is actually a DistributedLDAModel instance produced by theExpectation-Maximization (“em”) optimizer, then this method could involve collecting a large amountof data to the driver (on the order of vocabSize x k).
New in version 2.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
vocabSize()
Vocabulary size (number of terms or words in the vocabulary).
New in version 2.0.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
docConcentration = Param(parent='undefined', name='docConcentration', doc='Concentration parameter (commonly named "alpha") for the prior placed on documents\' distributions over topics ("theta").')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
k = Param(parent='undefined', name='k', doc='The number of topics (clusters) to infer. Must be > 1.')
keepLastCheckpoint = Param(parent='undefined', name='keepLastCheckpoint', doc='(For EM optimizer) If using checkpointing, this indicates whether to keep the last checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with care.')
learningDecay = Param(parent='undefined', name='learningDecay', doc='Learning rate, set as an exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic convergence.')
learningOffset = Param(parent='undefined', name='learningOffset', doc='A (positive) learning parameter that downweights early iterations. Larger values make early iterations count less')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
optimizeDocConcentration = Param(parent='undefined', name='optimizeDocConcentration', doc='Indicates whether the docConcentration (Dirichlet parameter for document-topic distribution) will be optimized during training.')
optimizer = Param(parent='undefined', name='optimizer', doc='Optimizer or inference algorithm used to estimate the LDA model. Supported: online, em')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
seed = Param(parent='undefined', name='seed', doc='random seed.')
subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].')
topicConcentration = Param(parent='undefined', name='topicConcentration', doc='Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics\' distributions over terms.')
topicDistributionCol = Param(parent='undefined', name='topicDistributionCol', doc='Output column with estimates of the topic mixture distribution for each document (often called "theta" in the literature). Returns a vector of zeros for an empty document.')
DistributedLDAModel
class pyspark.ml.clustering.DistributedLDAModel(java_model=None)
Distributed model fitted by LDA. This type of model is currently only produced by Expectation-Maximization (EM).
This model stores the inferred topics, the full training dataset, and the topic distribution for each training document.
New in version 2.0.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
describeTopics([maxTermsPerTopic]) Return the topics described by their top-weighted terms.
estimatedDocConcentration() Value for LDA.docConcentration estimated from data.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCheckpointFiles() If using checkpointing and LDA.keepLastCheckpoint is set to true, then there may be saved checkpoint files.
getCheckpointInterval() Gets the value of checkpointInterval or its default value.
getDocConcentration() Gets the value of docConcentration or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getK() Gets the value of k or its default value.
getKeepLastCheckpoint() Gets the value of keepLastCheckpoint or its default value.
getLearningDecay() Gets the value of learningDecay or its default value.
getLearningOffset() Gets the value of learningOffset or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOptimizeDocConcentration() Gets the value of optimizeDocConcentration or its default value.
getOptimizer() Gets the value of optimizer or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getSeed() Gets the value of seed or its default value.
getSubsamplingRate() Gets the value of subsamplingRate or its default value.
getTopicConcentration() Gets the value of topicConcentration or its default value.
getTopicDistributionCol() Gets the value of topicDistributionCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isDistributed() Indicates whether this instance is of type DistributedLDAModel.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
logLikelihood(dataset) Calculates a lower bound on the log likelihood of the entire corpus.
logPerplexity(dataset) Calculate an upper bound on perplexity.
logPrior() Log probability of the current parameter estimate: log P(topics, topic distributions for docs | alpha, eta)
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setSeed(value) Sets the value of seed.
setTopicDistributionCol(value) Sets the value of topicDistributionCol.
toLocal() Convert this distributed model to a local representation.
topicsMatrix() Inferred topics, where each topic is represented by a distribution over terms.
trainingLogLikelihood() Log likelihood of the observed tokens in the training set, given the current parameter estimates: log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters)
transform(dataset[, params]) Transforms the input dataset with optional parameters.
vocabSize() Vocabulary size (number of terms or words in the vocabulary)
write() Returns an MLWriter instance for this ML instance.
Attributes
checkpointInterval
docConcentration
featuresCol
k
keepLastCheckpoint
learningDecay
learningOffset
maxIter
optimizeDocConcentration
optimizer
params Returns all params ordered by name.
seed
subsamplingRate
topicConcentration
topicDistributionCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.
copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
describeTopics(maxTermsPerTopic=10)
Return the topics described by their top-weighted terms.
New in version 2.0.0.
estimatedDocConcentration()
Value for LDA.docConcentration estimated from data. If Online LDA was used and LDA.optimizeDocConcentration was set to false, then this returns the fixed (given) value for the LDA.docConcentration parameter.
New in version 2.0.0.
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getCheckpointFiles()
If using checkpointing and LDA.keepLastCheckpoint is set to true, then there may be saved checkpoint files. This method is provided so that users can manage those files.
New in version 2.0.0.
Returns
list List of checkpoint files from training
Notes
Removing the checkpoints can cause failures if a partition is lost and is needed by certainDistributedLDAModel methods. Reference counting will clean up the checkpoints when this modeland derivative data go out of scope.
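A sketch of listing those files, assuming distributed_model is a DistributedLDAModel trained with a checkpoint directory configured on the SparkContext:

>>> for path in distributed_model.getCheckpointFiles():
...     print(path)  # clean these up manually once the model is persisted elsewhere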
getCheckpointInterval()
Gets the value of checkpointInterval or its default value.
getDocConcentration()
Gets the value of docConcentration or its default value.
New in version 2.0.0.
getFeaturesCol()
Gets the value of featuresCol or its default value.
getK()
Gets the value of k or its default value.
New in version 2.0.0.
getKeepLastCheckpoint()
Gets the value of keepLastCheckpoint or its default value.
New in version 2.0.0.
getLearningDecay()
Gets the value of learningDecay or its default value.
New in version 2.0.0.
getLearningOffset()
Gets the value of learningOffset or its default value.
New in version 2.0.0.
getMaxIter()
Gets the value of maxIter or its default value.
getOptimizeDocConcentration()
Gets the value of optimizeDocConcentration or its default value.
New in version 2.0.0.
getOptimizer()
Gets the value of optimizer or its default value.
New in version 2.0.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getParam(paramName)
Gets a param by its name.
getSeed()
Gets the value of seed or its default value.
getSubsamplingRate()
Gets the value of subsamplingRate or its default value.
New in version 2.0.0.
getTopicConcentration()
Gets the value of topicConcentration or its default value.
New in version 2.0.0.
getTopicDistributionCol()
Gets the value of topicDistributionCol or its default value.
New in version 2.0.0.
hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.
isDefined(param)
Checks whether a param is explicitly set by user or has a default value.
isDistributed()
Indicates whether this instance is of type DistributedLDAModel.
New in version 2.0.0.
isSet(param)
Checks whether a param is explicitly set by user.
classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
logLikelihood(dataset)
Calculates a lower bound on the log likelihood of the entire corpus. See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
Warning: If this model is an instance of DistributedLDAModel (produced when optimizeris set to “em”), this involves collecting a large topicsMatrix() to the driver. This implementationmay be changed in the future.
New in version 2.0.0.
logPerplexity(dataset)
Calculate an upper bound on perplexity. (Lower is better.) See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
Warning: If this model is an instance of DistributedLDAModel (produced when optimizeris set to “em”), this involves collecting a large topicsMatrix() to the driver. This implementationmay be changed in the future.
New in version 2.0.0.
logPrior()
Log probability of the current parameter estimate: log P(topics, topic distributions for docs | alpha, eta)
New in version 2.0.0.
classmethod read()
Returns an MLReader instance for this class.
save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)
Sets a parameter in the embedded param map.
setFeaturesCol(value)
Sets the value of featuresCol.
New in version 3.0.0.
setSeed(value)
Sets the value of seed.
New in version 3.0.0.
setTopicDistributionCol(value)
Sets the value of topicDistributionCol.
New in version 3.0.0.
toLocal()
Convert this distributed model to a local representation. This discards info about the training dataset.
Warning: This involves collecting a large topicsMatrix() to the driver.
New in version 2.0.0.
topicsMatrix()
Inferred topics, where each topic is represented by a distribution over terms. This is a matrix of size vocabSize x k, where each column is a topic. No guarantees are given about the ordering of the topics.
Warning: If this model is actually a DistributedLDAModel instance produced by theExpectation-Maximization (“em”) optimizer, then this method could involve collecting a large amountof data to the driver (on the order of vocabSize x k).
New in version 2.0.0.
trainingLogLikelihood()
Log likelihood of the observed tokens in the training set, given the current parameter estimates: log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters)
Notes
• This excludes the prior; for that, use logPrior().
• Even with logPrior(), this is NOT the same as the data log likelihood given the hyperparameters.
• This is computed from the topic distributions computed during training. If you call logLikelihood() on the same training dataset, the topic distributions will be computed again, possibly giving different results.
New in version 2.0.0.
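A sketch combining the two quantities as the first note suggests, assuming distributed_model is a fitted DistributedLDAModel:

>>> train_ll = distributed_model.trainingLogLikelihood()
>>> joint = train_ll + distributed_model.logPrior()  # log probability of data and current parameter estimates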
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
vocabSize()
Vocabulary size (number of terms or words in the vocabulary).
New in version 2.0.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
docConcentration = Param(parent='undefined', name='docConcentration', doc='Concentration parameter (commonly named "alpha") for the prior placed on documents\' distributions over topics ("theta").')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
k = Param(parent='undefined', name='k', doc='The number of topics (clusters) to infer. Must be > 1.')
keepLastCheckpoint = Param(parent='undefined', name='keepLastCheckpoint', doc='(For EM optimizer) If using checkpointing, this indicates whether to keep the last checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with care.')
learningDecay = Param(parent='undefined', name='learningDecay', doc='Learning rate, set as an exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic convergence.')
learningOffset = Param(parent='undefined', name='learningOffset', doc='A (positive) learning parameter that downweights early iterations. Larger values make early iterations count less')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
optimizeDocConcentration = Param(parent='undefined', name='optimizeDocConcentration', doc='Indicates whether the docConcentration (Dirichlet parameter for document-topic distribution) will be optimized during training.')
optimizer = Param(parent='undefined', name='optimizer', doc='Optimizer or inference algorithm used to estimate the LDA model. Supported: online, em')
792 Chapter 3. API Reference
pyspark Documentation, Release master
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
seed = Param(parent='undefined', name='seed', doc='random seed.')
subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].')
topicConcentration = Param(parent='undefined', name='topicConcentration', doc='Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics\' distributions over terms.')
topicDistributionCol = Param(parent='undefined', name='topicDistributionCol', doc='Output column with estimates of the topic mixture distribution for each document (often called "theta" in the literature). Returns a vector of zeros for an empty document.')
PowerIterationClustering
class pyspark.ml.clustering.PowerIterationClustering(*args, **kwargs)
Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and Cohen. From the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise similarity matrix of the data.
This class is not yet an Estimator/Transformer; use the assignClusters() method to run the PowerIterationClustering algorithm.
New in version 2.4.0.
Notes
See Wikipedia on Spectral clustering
Examples
>>> data = [(1, 0, 0.5),
...         (2, 0, 0.5), (2, 1, 0.7),
...         (3, 0, 0.5), (3, 1, 0.7), (3, 2, 0.9),
...         (4, 0, 0.5), (4, 1, 0.7), (4, 2, 0.9), (4, 3, 1.1),
...         (5, 0, 0.5), (5, 1, 0.7), (5, 2, 0.9), (5, 3, 1.1), (5, 4, 1.3)]
>>> df = spark.createDataFrame(data).toDF("src", "dst", "weight").repartition(1)
>>> pic = PowerIterationClustering(k=2, weightCol="weight")
>>> pic.setMaxIter(40)
PowerIterationClustering...
>>> assignments = pic.assignClusters(df)
>>> assignments.sort(assignments.id).show(truncate=False)
+---+-------+
|id |cluster|
+---+-------+
|0  |0      |
|1  |0      |
|2  |0      |
|3  |0      |
|4  |0      |
|5  |1      |
+---+-------+
...
>>> pic_path = temp_path + "/pic"
>>> pic.save(pic_path)
>>> pic2 = PowerIterationClustering.load(pic_path)
>>> pic2.getK()
2
>>> pic2.getMaxIter()
40
>>> pic2.assignClusters(df).take(6) == assignments.take(6)
True
Methods
assignClusters(dataset) Run the PIC algorithm and returns a cluster assignment for each input vertex.
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getDstCol() Gets the value of dstCol or its default value.
getInitMode() Gets the value of initMode or its default value.
getK() Gets the value of k or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getSrcCol() Gets the value of srcCol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setDstCol(value) Sets the value of dstCol.
setInitMode(value) Sets the value of initMode.
setK(value) Sets the value of k.
setMaxIter(value) Sets the value of maxIter.
setParams(self, \*[, k, maxIter, initMode, ...]) Sets params for PowerIterationClustering.
setSrcCol(value) Sets the value of srcCol.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.
Attributes
dstCol
initMode
k
maxIter
params Returns all params ordered by name.
srcCol
weightCol
Methods Documentation
assignClusters(dataset)
Run the PIC algorithm and returns a cluster assignment for each input vertex.
Parameters
dataset [pyspark.sql.DataFrame] A dataset with columns src, dst, weight representing the affinity matrix, which is the matrix A in the PIC paper. Suppose the src column value is i, the dst column value is j, the weight column value is similarity s_ij which must be nonnegative. This is a symmetric matrix and hence s_ij = s_ji. For any (i, j) with nonzero similarity, there should be either (i, j, s_ij) or (j, i, s_ji) in the input. Rows with i = j are ignored, because we assume s_ij = 0.0.
Returns
pyspark.sql.DataFrame A dataset that contains columns of vertex id and the corresponding cluster for the id. The schema of it will be: id: Long, cluster: Int.
New in version 2.4.0.
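A small sketch confirming the documented output schema, reusing pic and df from the class example above:

>>> [field.name for field in pic.assignClusters(df).schema.fields]
['id', 'cluster']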
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
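For illustration, a minimal sketch of the merge ordering (default values < user-supplied values < extra), assuming the pic instance constructed in the example above:

>>> pm = pic.extractParamMap({pic.k: 5})  # 'extra' overrides the user-supplied k=2
>>> pm[pic.k]
5
>>> pic.extractParamMap()[pic.maxIter]  # maxIter was explicitly set to 40 above
40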
getDstCol()
Gets the value of dstCol or its default value.
New in version 2.4.0.
getInitMode()
Gets the value of initMode or its default value.
New in version 2.4.0.
getK()
Gets the value of k or its default value.
New in version 2.4.0.
getMaxIter()
Gets the value of maxIter or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getSrcCol()
Gets the value of srcCol or its default value.
New in version 2.4.0.
getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setDstCol(value)
Sets the value of dstCol.
New in version 2.4.0.
setInitMode(value)
Sets the value of initMode.
New in version 2.4.0.
setK(value)
Sets the value of k.
New in version 2.4.0.
setMaxIter(value)
Sets the value of maxIter.
New in version 2.4.0.
setParams(self, *, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", weightCol=None)
Sets params for PowerIterationClustering.
New in version 2.4.0.
setSrcCol(value)
Sets the value of srcCol.
New in version 2.4.0.
setWeightCol(value)
Sets the value of weightCol.
New in version 2.4.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
dstCol = Param(parent='undefined', name='dstCol', doc='Name of the input column for destination vertex IDs.')
initMode = Param(parent='undefined', name='initMode', doc="The initialization algorithm. This can be either 'random' to use a random vector as vertex properties, or 'degree' to use a normalized sum of similarities with other vertices. Supported options: 'random' and 'degree'.")
k = Param(parent='undefined', name='k', doc='The number of clusters to create. Must be > 1.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
srcCol = Param(parent='undefined', name='srcCol', doc='Name of the input column for source vertex IDs.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
3.3.6 ML Functions
array_to_vector(col) - Converts a column of array of numeric type into a column of dense vectors in MLlib.
vector_to_array(col[, dtype]) - Converts a column of MLlib sparse/dense vectors into a column of dense arrays.
pyspark.ml.functions.array_to_vector
pyspark.ml.functions.array_to_vector(col)
Converts a column of array of numeric type into a column of dense vectors in MLlib.
New in version 3.1.0.
Parameters
col [pyspark.sql.Column or str] Input column
Returns
pyspark.sql.Column The converted column of MLlib dense vectors.
Examples
>>> from pyspark.ml.functions import array_to_vector
>>> df1 = spark.createDataFrame([([1.5, 2.5],),], schema='v1 array<double>')
>>> df1.select(array_to_vector('v1').alias('vec1')).collect()
[Row(vec1=DenseVector([1.5, 2.5]))]
>>> df2 = spark.createDataFrame([([1.5, 3.5],),], schema='v1 array<float>')
>>> df2.select(array_to_vector('v1').alias('vec1')).collect()
[Row(vec1=DenseVector([1.5, 3.5]))]
>>> df3 = spark.createDataFrame([([1, 3],),], schema='v1 array<int>')
>>> df3.select(array_to_vector('v1').alias('vec1')).collect()
[Row(vec1=DenseVector([1.0, 3.0]))]
pyspark.ml.functions.vector_to_array
pyspark.ml.functions.vector_to_array(col, dtype='float64')
Converts a column of MLlib sparse/dense vectors into a column of dense arrays.
New in version 3.0.0.
Parameters
col [pyspark.sql.Column or str] Input column
dtype [str, optional] The data type of the output array. Valid values: “float64” or “float32”.
Returns
pyspark.sql.Column The converted column of dense arrays.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.functions import vector_to_array
>>> from pyspark.mllib.linalg import Vectors as OldVectors
>>> df = spark.createDataFrame([
...     (Vectors.dense(1.0, 2.0, 3.0), OldVectors.dense(10.0, 20.0, 30.0)),
...     (Vectors.sparse(3, [(0, 2.0), (2, 3.0)]),
...      OldVectors.sparse(3, [(0, 20.0), (2, 30.0)]))],
...     ["vec", "oldVec"])
>>> df1 = df.select(vector_to_array("vec").alias("vec"),
...                 vector_to_array("oldVec").alias("oldVec"))
>>> df1.collect()
[Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]),
 Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])]

>>> df2 = df.select(vector_to_array("vec", "float32").alias("vec"),
...                 vector_to_array("oldVec", "float32").alias("oldVec"))
>>> df2.collect()
[Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]),
 Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])]

>>> df1.schema.fields
[StructField(vec,ArrayType(DoubleType,false),false),
 StructField(oldVec,ArrayType(DoubleType,false),false)]
>>> df2.schema.fields
[StructField(vec,ArrayType(FloatType,false),false),
 StructField(oldVec,ArrayType(FloatType,false),false)]
3.3.7 Vector and Matrix
Vector()
DenseVector(ar) - A dense vector represented by a value array.
SparseVector(size, *args) - A simple sparse vector class for passing data to MLlib.
Vectors() - Factory methods for working with vectors.
Matrix(numRows, numCols[, isTransposed])
DenseMatrix(numRows, numCols, values[, ...]) - Column-major dense matrix.
SparseMatrix(numRows, numCols, colPtrs, ...) - Sparse Matrix stored in CSC format.
Matrices()
Vector
class pyspark.ml.linalg.Vector
Methods
toArray() - Convert the vector into a numpy.ndarray
Methods Documentation
toArray()
Convert the vector into a numpy.ndarray

Returns
numpy.ndarray
DenseVector
class pyspark.ml.linalg.DenseVector(ar)
A dense vector represented by a value array. We use a numpy array for storage, and arithmetic will be delegated to the underlying numpy array.
Examples
>>> v = Vectors.dense([1.0, 2.0])
>>> u = Vectors.dense([3.0, 4.0])
>>> v + u
DenseVector([4.0, 6.0])
>>> 2 - v
DenseVector([1.0, 0.0])
>>> v / 2
DenseVector([0.5, 1.0])
>>> v * u
DenseVector([3.0, 8.0])
>>> u / v
DenseVector([3.0, 2.0])
>>> u % 2
DenseVector([1.0, 0.0])
>>> -v
DenseVector([-1.0, -2.0])
Methods
dot(other) - Compute the dot product of two Vectors.
norm(p) - Calculates the norm of a DenseVector.
numNonzeros() - Number of nonzero elements.
squared_distance(other) - Squared distance of two Vectors.
toArray() - Returns the underlying numpy.ndarray
Attributes
values Returns the underlying numpy.ndarray
Methods Documentation
dot(other)
Compute the dot product of two Vectors. We support (Numpy array, list, SparseVector, or SciPy sparse) and a target NumPy array that is either 1- or 2-dimensional. Equivalent to calling numpy.dot of the two vectors.
Examples
>>> dense = DenseVector(array.array('d', [1., 2.]))
>>> dense.dot(dense)
5.0
>>> dense.dot(SparseVector(2, [0, 1], [2., 1.]))
4.0
>>> dense.dot(range(1, 3))
5.0
>>> dense.dot(np.array(range(1, 3)))
5.0
>>> dense.dot([1.,])
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
>>> dense.dot(np.reshape([1., 2., 3., 4.], (2, 2), order='F'))
array([  5.,  11.])
>>> dense.dot(np.reshape([1., 2., 3.], (3, 1), order='F'))
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
norm(p)
Calculates the norm of a DenseVector.
Examples
>>> a = DenseVector([0, -1, 2, -3])
>>> a.norm(2)
3.7...
>>> a.norm(1)
6.0
numNonzeros()
Number of nonzero elements. This scans all active values and counts nonzeros.
squared_distance(other)
Squared distance of two Vectors.
Examples
>>> dense1 = DenseVector(array.array('d', [1., 2.]))
>>> dense1.squared_distance(dense1)
0.0
>>> dense2 = np.array([2., 1.])
>>> dense1.squared_distance(dense2)
2.0
>>> dense3 = [2., 1.]
>>> dense1.squared_distance(dense3)
2.0
>>> sparse1 = SparseVector(2, [0, 1], [2., 1.])
>>> dense1.squared_distance(sparse1)
2.0
>>> dense1.squared_distance([1.,])
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
>>> dense1.squared_distance(SparseVector(1, [0,], [1.,]))
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
toArray()
Returns the underlying numpy.ndarray
Attributes Documentation
values
Returns the underlying numpy.ndarray
SparseVector
class pyspark.ml.linalg.SparseVector(size, *args)
A simple sparse vector class for passing data to MLlib. Users may alternatively pass SciPy's scipy.sparse data types.
Methods
dot(other) - Dot product with a SparseVector or 1- or 2-dimensional Numpy array.
norm(p) - Calculates the norm of a SparseVector.
numNonzeros() - Number of nonzero elements.
squared_distance(other) - Squared distance from a SparseVector or 1-dimensional NumPy array.
toArray() - Returns a copy of this SparseVector as a 1-dimensional numpy.ndarray.
Methods Documentation
dot(other)
Dot product with a SparseVector or 1- or 2-dimensional Numpy array.
Examples
>>> a = SparseVector(4, [1, 3], [3.0, 4.0])
>>> a.dot(a)
25.0
>>> a.dot(array.array('d', [1., 2., 3., 4.]))
22.0
>>> b = SparseVector(4, [2], [1.0])
>>> a.dot(b)
0.0
>>> a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]))
array([ 22.,  22.])
>>> a.dot([1., 2., 3.])
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
>>> a.dot(np.array([1., 2.]))
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
>>> a.dot(DenseVector([1., 2.]))
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
>>> a.dot(np.zeros((3, 2)))
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
norm(p)
Calculates the norm of a SparseVector.
Examples
>>> a = SparseVector(4, [0, 1], [3., -4.])
>>> a.norm(1)
7.0
>>> a.norm(2)
5.0
numNonzeros()
Number of nonzero elements. This scans all active values and counts nonzeros.
squared_distance(other)
Squared distance from a SparseVector or 1-dimensional NumPy array.
Examples
>>> a = SparseVector(4, [1, 3], [3.0, 4.0])
>>> a.squared_distance(a)
0.0
>>> a.squared_distance(array.array('d', [1., 2., 3., 4.]))
11.0
>>> a.squared_distance(np.array([1., 2., 3., 4.]))
11.0
>>> b = SparseVector(4, [2], [1.0])
>>> a.squared_distance(b)
26.0
>>> b.squared_distance(a)
26.0
>>> b.squared_distance([1., 2.])
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
>>> b.squared_distance(SparseVector(3, [1,], [1.0,]))
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
toArray()
Returns a copy of this SparseVector as a 1-dimensional numpy.ndarray.
Vectors
class pyspark.ml.linalg.Vectors
Factory methods for working with vectors.
Notes
Dense vectors are simply represented as NumPy array objects, so there is no need to convert them for use in MLlib. For sparse vectors, the factory methods in this class create an MLlib-compatible type, or users can pass in SciPy's scipy.sparse column vectors.
Methods
dense(*elements) - Create a dense vector of 64-bit floats from a Python list or numbers.
norm(vector, p) - Find norm of the given vector.
sparse(size, *args) - Create a sparse vector, using either a dictionary, a list of (index, value) pairs, or two separate arrays of indices and values (sorted by index).
squared_distance(v1, v2) - Squared distance between two vectors.
zeros(size)
Methods Documentation
static dense(*elements)
Create a dense vector of 64-bit floats from a Python list or numbers.
Examples
>>> Vectors.dense([1, 2, 3])
DenseVector([1.0, 2.0, 3.0])
>>> Vectors.dense(1.0, 2.0)
DenseVector([1.0, 2.0])
static norm(vector, p)
Find norm of the given vector.
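A short illustrative sketch (values chosen so the norms are exact):

>>> a = Vectors.dense([3.0, -4.0])
>>> Vectors.norm(a, 2)  # Euclidean norm: sqrt(3^2 + 4^2)
5.0
>>> Vectors.norm(a, 1)  # sum of absolute values
7.0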
static sparse(size, *args)
Create a sparse vector, using either a dictionary, a list of (index, value) pairs, or two separate arrays of indices and values (sorted by index).
Parameters
size [int] Size of the vector.
args Non-zero entries, as a dictionary, list of tuples, or two sorted lists containing indicesand values.
Examples
>>> Vectors.sparse(4, {1: 1.0, 3: 5.5})
SparseVector(4, {1: 1.0, 3: 5.5})
>>> Vectors.sparse(4, [(1, 1.0), (3, 5.5)])
SparseVector(4, {1: 1.0, 3: 5.5})
>>> Vectors.sparse(4, [1, 3], [1.0, 5.5])
SparseVector(4, {1: 1.0, 3: 5.5})
static squared_distance(v1, v2)
Squared distance between two vectors. v1 and v2 can be of type SparseVector, DenseVector, np.ndarray or array.array.
Examples
>>> a = Vectors.sparse(4, [(0, 1), (3, 4)])
>>> b = Vectors.dense([2, 5, 4, 1])
>>> a.squared_distance(b)
51.0
static zeros(size)
Matrix
class pyspark.ml.linalg.Matrix(numRows, numCols, isTransposed=False)
Methods
toArray() Returns its elements in a numpy.ndarray.
Methods Documentation
toArray()
Returns its elements in a numpy.ndarray.
DenseMatrix
class pyspark.ml.linalg.DenseMatrix(numRows, numCols, values, isTransposed=False)
Column-major dense matrix.
Methods
toArray() - Return a numpy.ndarray
toSparse() - Convert to SparseMatrix
Methods Documentation
toArray()
Return a numpy.ndarray
Examples
>>> m = DenseMatrix(2, 2, range(4))
>>> m.toArray()
array([[ 0.,  2.],
       [ 1.,  3.]])
toSparse()
Convert to SparseMatrix
SparseMatrix
class pyspark.ml.linalg.SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed=False)
Sparse Matrix stored in CSC format.
Methods
toArray() - Return a numpy.ndarray
toDense()
Methods Documentation
toArray()
Return a numpy.ndarray
toDense()
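As an illustrative sketch of the CSC layout (colPtrs marks where each column's entries start in rowIndices and values; exact array print spacing may vary by numpy version):

>>> sm = SparseMatrix(2, 2, [0, 1, 2], [0, 1], [2.0, 3.0])  # one entry per column
>>> sm.toArray()
array([[ 2.,  0.],
       [ 0.,  3.]])
>>> sm.toDense().toArray()  # toDense() yields an equivalent DenseMatrix
array([[ 2.,  0.],
       [ 0.,  3.]])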
Matrices
class pyspark.ml.linalg.Matrices
Methods
dense(numRows, numCols, values) - Create a DenseMatrix
sparse(numRows, numCols, colPtrs, ...) - Create a SparseMatrix
Methods Documentation
static dense(numRows, numCols, values)
Create a DenseMatrix

static sparse(numRows, numCols, colPtrs, rowIndices, values)
Create a SparseMatrix
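A minimal sketch showing that the values passed to dense are read in column-major order:

>>> m = Matrices.dense(2, 2, [1.0, 2.0, 3.0, 4.0])
>>> m.toArray()  # first two values fill column 0, next two fill column 1
array([[ 1.,  3.],
       [ 2.,  4.]])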
3.3.8 Recommendation
ALS(*args, **kwargs) - Alternating Least Squares (ALS) matrix factorization.
ALSModel([java_model]) - Model fitted by ALS.
ALS
class pyspark.ml.recommendation.ALS(*args, **kwargs)
Alternating Least Squares (ALS) matrix factorization.
ALS attempts to estimate the ratings matrix R as the product of two lower-rank matrices, X and Y, i.e. X * Yt = R. Typically these approximations are called 'factor' matrices. The general approach is iterative. During each iteration, one of the factor matrices is held constant, while the other is solved for using least squares. The newly-solved factor matrix is then held constant while solving for the other factor matrix.
This is a blocked implementation of the ALS factorization algorithm that groups the two sets of factors (referredto as “users” and “products”) into blocks and reduces communication by only sending one copy of each uservector to each product block on each iteration, and only for the product blocks that need that user’s featurevector. This is achieved by pre-computing some information about the ratings matrix to determine the “out-links” of each user (which blocks of products it will contribute to) and “in-link” information for each product(which of the feature vectors it receives from each user block it will depend on). This allows us to send only anarray of feature vectors between each user block and product block, and have the product block find the users’ratings and update the products based on these messages.
For implicit preference data, the algorithm used is based on "Collaborative Filtering for Implicit Feedback Datasets", adapted for the blocked approach used here.
Essentially instead of finding the low-rank approximations to the rating matrix R, this finds the approximationsfor a preference matrix P where the elements of P are 1 if r > 0 and 0 if r <= 0. The ratings then act as‘confidence’ values related to strength of indicated user preferences rather than explicit ratings given to items.
New in version 1.4.0.
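For illustration, a minimal sketch of switching to the implicit-preference formulation (the parameter values here are arbitrary choices for demonstration, not tuned recommendations):

>>> als_implicit = ALS(rank=10, maxIter=5, implicitPrefs=True, alpha=1.0,
...                    userCol="user", itemCol="item", ratingCol="rating")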
Notes
The input rating dataframe to the ALS implementation should be deterministic. Nondeterministic data can cause failures when fitting the ALS model. For example, an order-sensitive operation like sampling after a repartition makes the dataframe output nondeterministic, as in df.repartition(2).sample(False, 0.5, 1618). Checkpointing the sampled dataframe or adding a sort before sampling can help make the dataframe deterministic.
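A hedged sketch of one way to make such an input deterministic before fitting; df here stands for any ratings dataframe with user/item/rating columns, and the checkpoint directory path is an assumption for illustration:

>>> spark.sparkContext.setCheckpointDir("/tmp/checkpoints")  # assumed writable path
>>> ratings = df.repartition(2).sample(False, 0.5, 1618).checkpoint()
>>> model = ALS(userCol="user", itemCol="item", ratingCol="rating").fit(ratings)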
Examples
>>> df = spark.createDataFrame(
...     [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
...     ["user", "item", "rating"])
>>> als = ALS(rank=10, seed=0)
>>> als.setMaxIter(5)
ALS...
>>> als.getMaxIter()
5
>>> als.setRegParam(0.1)
ALS...
>>> als.getRegParam()
0.1
>>> als.clear(als.regParam)
>>> model = als.fit(df)
>>> model.getBlockSize()
4096
>>> model.getUserCol()
'user'
>>> model.setUserCol("user")
ALSModel...
>>> model.getItemCol()
'item'
>>> model.setPredictionCol("newPrediction")
ALS...
>>> model.rank
10
>>> model.userFactors.orderBy("id").collect()
[Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
>>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
>>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
>>> predictions[0]
Row(user=0, item=2, newPrediction=0.692910...)
>>> predictions[1]
Row(user=1, item=0, newPrediction=3.473569...)
>>> predictions[2]
Row(user=2, item=0, newPrediction=-0.899198...)
>>> user_recs = model.recommendForAllUsers(3)
>>> user_recs.where(user_recs.user == 0).select("recommendations.item", "recommendations.rating").collect()
[Row(item=[0, 1, 2], rating=[3.910..., 1.997..., 0.692...])]
>>> item_recs = model.recommendForAllItems(3)
>>> item_recs.where(item_recs.item == 2).select("recommendations.user", "recommendations.rating").collect()
[Row(user=[2, 1, 0], rating=[4.892..., 3.991..., 0.692...])]
>>> user_subset = df.where(df.user == 2)
>>> user_subset_recs = model.recommendForUserSubset(user_subset, 3)
>>> user_subset_recs.select("recommendations.item", "recommendations.rating").first()
Row(item=[2, 1, 0], rating=[4.892..., 1.076..., -0.899...])
>>> item_subset = df.where(df.item == 0)
>>> item_subset_recs = model.recommendForItemSubset(item_subset, 3)
>>> item_subset_recs.select("recommendations.user", "recommendations.rating").first()
Row(user=[0, 1, 2], rating=[3.910..., 3.473..., -0.899...])
>>> als_path = temp_path + "/als"
>>> als.save(als_path)
>>> als2 = ALS.load(als_path)
>>> als.getMaxIter()
5
>>> model_path = temp_path + "/als_model"
>>> model.save(model_path)
>>> model2 = ALSModel.load(model_path)
>>> model.rank == model2.rank
True
>>> sorted(model.userFactors.collect()) == sorted(model2.userFactors.collect())
True
>>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect())
True
>>> model.transform(test).take(1) == model2.transform(test).take(1)
True
Methods
clear(param) - Clears a param from the param map if it has been explicitly set.
copy([extra]) - Creates a copy of this instance with the same uid and some extra params.
explainParam(param) - Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() - Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) - Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) - Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) - Fits a model to the input dataset for each param map in paramMaps.
getAlpha() - Gets the value of alpha or its default value.
getBlockSize() - Gets the value of blockSize or its default value.
getCheckpointInterval() - Gets the value of checkpointInterval or its default value.
getColdStartStrategy() - Gets the value of coldStartStrategy or its default value.
getFinalStorageLevel() - Gets the value of finalStorageLevel or its default value.
getImplicitPrefs() - Gets the value of implicitPrefs or its default value.
getIntermediateStorageLevel() - Gets the value of intermediateStorageLevel or its default value.
getItemCol() - Gets the value of itemCol or its default value.
getMaxIter() - Gets the value of maxIter or its default value.
getNonnegative() - Gets the value of nonnegative or its default value.
getNumItemBlocks() - Gets the value of numItemBlocks or its default value.
getNumUserBlocks() - Gets the value of numUserBlocks or its default value.
getOrDefault(param) - Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) - Gets a param by its name.
getPredictionCol() - Gets the value of predictionCol or its default value.
getRank() - Gets the value of rank or its default value.
getRatingCol() - Gets the value of ratingCol or its default value.
getRegParam() - Gets the value of regParam or its default value.
getSeed() - Gets the value of seed or its default value.
getUserCol() - Gets the value of userCol or its default value.
hasDefault(param) - Checks whether a param has a default value.
hasParam(paramName) - Tests whether this instance contains a param with a given (string) name.
isDefined(param) - Checks whether a param is explicitly set by user or has a default value.
isSet(param) - Checks whether a param is explicitly set by user.
load(path) - Reads an ML instance from the input path, a shortcut of read().load(path).
read() - Returns an MLReader instance for this class.
save(path) - Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) - Sets a parameter in the embedded param map.
setAlpha(value) - Sets the value of alpha.
setBlockSize(value) - Sets the value of blockSize.
setCheckpointInterval(value) - Sets the value of checkpointInterval.
setColdStartStrategy(value) - Sets the value of coldStartStrategy.
setFinalStorageLevel(value) - Sets the value of finalStorageLevel.
setImplicitPrefs(value) - Sets the value of implicitPrefs.
setIntermediateStorageLevel(value) - Sets the value of intermediateStorageLevel.
setItemCol(value) - Sets the value of itemCol.
setMaxIter(value) - Sets the value of maxIter.
setNonnegative(value) - Sets the value of nonnegative.
setNumBlocks(value) - Sets both numUserBlocks and numItemBlocks to the specific value.
setNumItemBlocks(value) - Sets the value of numItemBlocks.
setNumUserBlocks(value) - Sets the value of numUserBlocks.
setParams(self, *[, rank, maxIter, ...]) - Sets params for ALS.
setPredictionCol(value) - Sets the value of predictionCol.
setRank(value) - Sets the value of rank.
setRatingCol(value) - Sets the value of ratingCol.
setRegParam(value) - Sets the value of regParam.
setSeed(value) - Sets the value of seed.
setUserCol(value) - Sets the value of userCol.
write() - Returns an MLWriter instance for this ML instance.
Attributes
alpha
blockSize
checkpointInterval
coldStartStrategy
finalStorageLevel
implicitPrefs
intermediateStorageLevel
itemCol
maxIter
nonnegative
numItemBlocks
numUserBlocks
params - Returns all params ordered by name.
predictionCol
rank
ratingCol
regParam
seed
userCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
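An illustrative sketch of overriding params per fit call, reusing the als and df names from the class example above:

>>> model5 = als.fit(df, {als.maxIter: 5})  # a single param map returns one model
>>> models = als.fit(df, [{als.maxIter: 5}, {als.maxIter: 10}])  # a list returns a list of models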
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
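For illustration, a minimal sketch of consuming this iterator (als and df as in the class example above); note the models may arrive out of order, so the returned index is used to map each model back to its param map:

>>> maps = [{als.maxIter: 5}, {als.maxIter: 10}]
>>> models = [None] * len(maps)
>>> for index, model in als.fitMultiple(df, maps):
...     models[index] = model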
getAlpha()
Gets the value of alpha or its default value.
New in version 1.4.0.
getBlockSize()
Gets the value of blockSize or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getColdStartStrategy()
Gets the value of coldStartStrategy or its default value.
New in version 2.2.0.
getFinalStorageLevel()
Gets the value of finalStorageLevel or its default value.
New in version 2.0.0.
getImplicitPrefs()
Gets the value of implicitPrefs or its default value.
New in version 1.4.0.
getIntermediateStorageLevel()
Gets the value of intermediateStorageLevel or its default value.
New in version 2.0.0.
getItemCol()
Gets the value of itemCol or its default value.
New in version 1.4.0.
getMaxIter()
Gets the value of maxIter or its default value.

getNonnegative()
Gets the value of nonnegative or its default value.
New in version 1.4.0.
getNumItemBlocks()
Gets the value of numItemBlocks or its default value.
New in version 1.4.0.
getNumUserBlocks()
Gets the value of numUserBlocks or its default value.
New in version 1.4.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getRank()
Gets the value of rank or its default value.
New in version 1.4.0.
getRatingCol()
Gets the value of ratingCol or its default value.
New in version 1.4.0.
getRegParam()
Gets the value of regParam or its default value.

getSeed()
Gets the value of seed or its default value.

getUserCol()
Gets the value of userCol or its default value.
New in version 1.4.0.
hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setAlpha(value)
Sets the value of alpha.
New in version 1.4.0.
setBlockSize(value)
Sets the value of blockSize.
New in version 3.0.0.
setCheckpointInterval(value)
Sets the value of checkpointInterval.

setColdStartStrategy(value)
Sets the value of coldStartStrategy.
New in version 2.2.0.
setFinalStorageLevel(value)
Sets the value of finalStorageLevel.
New in version 2.0.0.
setImplicitPrefs(value)
Sets the value of implicitPrefs.
New in version 1.4.0.
setIntermediateStorageLevel(value)
Sets the value of intermediateStorageLevel.
New in version 2.0.0.
setItemCol(value)
Sets the value of itemCol.
New in version 1.4.0.
setMaxIter(value)
Sets the value of maxIter.

setNonnegative(value)
Sets the value of nonnegative.
New in version 1.4.0.
setNumBlocks(value)
Sets both numUserBlocks and numItemBlocks to the specific value.
New in version 1.4.0.
setNumItemBlocks(value)
Sets the value of numItemBlocks.
New in version 1.4.0.
setNumUserBlocks(value)
Sets the value of numUserBlocks.
New in version 1.4.0.
setParams(self, *, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096)
Sets params for ALS.
New in version 1.4.0.
setPredictionCol(value)
Sets the value of predictionCol.

setRank(value)
Sets the value of rank.
New in version 1.4.0.
setRatingCol(value)
Sets the value of ratingCol.
New in version 1.4.0.
setRegParam(value)
Sets the value of regParam.

setSeed(value)
Sets the value of seed.

setUserCol(value)
Sets the value of userCol.
New in version 1.4.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
alpha = Param(parent='undefined', name='alpha', doc='alpha for implicit preference')
blockSize = Param(parent='undefined', name='blockSize', doc='block size for stacking input data in matrices. Data is stacked within partitions. If block size is more than remaining data in a partition then it is adjusted to the size of this data.')
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
coldStartStrategy = Param(parent='undefined', name='coldStartStrategy', doc="strategy for dealing with unknown or new users/items at prediction time. This may be useful in cross-validation or production scenarios, for handling user/item ids the model has not seen in the training data. Supported values: 'nan', 'drop'.")
finalStorageLevel = Param(parent='undefined', name='finalStorageLevel', doc='StorageLevel for ALS model factors.')
implicitPrefs = Param(parent='undefined', name='implicitPrefs', doc='whether to use implicit preference')
intermediateStorageLevel = Param(parent='undefined', name='intermediateStorageLevel', doc="StorageLevel for intermediate datasets. Cannot be 'NONE'.")
itemCol = Param(parent='undefined', name='itemCol', doc='column name for item ids. Ids must be within the integer value range.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
nonnegative = Param(parent='undefined', name='nonnegative', doc='whether to use nonnegative constraint for least squares')
numItemBlocks = Param(parent='undefined', name='numItemBlocks', doc='number of item blocks')
numUserBlocks = Param(parent='undefined', name='numUserBlocks', doc='number of user blocks')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
rank = Param(parent='undefined', name='rank', doc='rank of the factorization')
ratingCol = Param(parent='undefined', name='ratingCol', doc='column name for ratings')
regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')
seed = Param(parent='undefined', name='seed', doc='random seed.')
userCol = Param(parent='undefined', name='userCol', doc='column name for user ids. Ids must be within the integer value range.')
ALSModel
class pyspark.ml.recommendation.ALSModel(java_model=None)
Model fitted by ALS.
New in version 1.4.0.
Methods
clear(param) - Clears a param from the param map if it has been explicitly set.
copy([extra]) - Creates a copy of this instance with the same uid and some extra params.
explainParam(param) - Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() - Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) - Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBlockSize() - Gets the value of blockSize or its default value.
getColdStartStrategy() - Gets the value of coldStartStrategy or its default value.
getItemCol() - Gets the value of itemCol or its default value.
getOrDefault(param) - Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) - Gets a param by its name.
getPredictionCol() - Gets the value of predictionCol or its default value.
getUserCol() - Gets the value of userCol or its default value.
hasDefault(param) - Checks whether a param has a default value.
hasParam(paramName) - Tests whether this instance contains a param with a given (string) name.
isDefined(param) - Checks whether a param is explicitly set by user or has a default value.
isSet(param) - Checks whether a param is explicitly set by user.
load(path) - Reads an ML instance from the input path, a shortcut of read().load(path).
read() - Returns an MLReader instance for this class.
recommendForAllItems(numUsers) - Returns top numUsers users recommended for each item, for all items.
recommendForAllUsers(numItems) - Returns top numItems items recommended for each user, for all users.
recommendForItemSubset(dataset, numUsers) - Returns top numUsers users recommended for each item id in the input data set.
recommendForUserSubset(dataset, numItems) - Returns top numItems items recommended for each user id in the input data set.
save(path) - Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) - Sets a parameter in the embedded param map.
setBlockSize(value) - Sets the value of blockSize.
setColdStartStrategy(value) - Sets the value of coldStartStrategy.
setItemCol(value) - Sets the value of itemCol.
setPredictionCol(value) - Sets the value of predictionCol.
setUserCol(value) - Sets the value of userCol.
transform(dataset[, params]) - Transforms the input dataset with optional parameters.
write() - Returns an MLWriter instance for this ML instance.
Attributes
blockSize
coldStartStrategy
itemCol
itemFactors - a DataFrame that stores item factors in two columns: id and features
params - Returns all params ordered by name.
predictionCol
rank - rank of the matrix factorization model
userCol
userFactors - a DataFrame that stores user factors in two columns: id and features
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getBlockSize()
Gets the value of blockSize or its default value.

getColdStartStrategy()
Gets the value of coldStartStrategy or its default value.
New in version 2.2.0.
getItemCol()
Gets the value of itemCol or its default value.
New in version 1.4.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getUserCol()
Gets the value of userCol or its default value.
New in version 1.4.0.
hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

recommendForAllItems(numUsers)
Returns top numUsers users recommended for each item, for all items.
New in version 2.2.0.
Parameters
numUsers [int] max number of recommendations for each item
Returns
pyspark.sql.DataFrame a DataFrame of (itemCol, recommendations), where recommendations are stored as an array of (userCol, rating) Rows.
recommendForAllUsers(numItems)
Returns top numItems items recommended for each user, for all users.
New in version 2.2.0.
Parameters
numItems [int] max number of recommendations for each user
Returns
pyspark.sql.DataFrame a DataFrame of (userCol, recommendations), where recommendations are stored as an array of (itemCol, rating) Rows.

recommendForItemSubset(dataset, numUsers)
Returns top numUsers users recommended for each item id in the input data set. Note that if there are duplicate ids in the input dataset, only one set of recommendations per unique id will be returned.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] a DataFrame containing a column of item ids. The column name must match itemCol.
numUsers [int] max number of recommendations for each item
Returns
pyspark.sql.DataFrame a DataFrame of (itemCol, recommendations), where recommendations are stored as an array of (userCol, rating) Rows.

recommendForUserSubset(dataset, numItems)
Returns top numItems items recommended for each user id in the input data set. Note that if there are duplicate ids in the input dataset, only one set of recommendations per unique id will be returned.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] a DataFrame containing a column of user ids. The column name must match userCol.
numItems [int] max number of recommendations for each user
Returns
pyspark.sql.DataFrame a DataFrame of (userCol, recommendations), where recommendations are stored as an array of (itemCol, rating) Rows.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setBlockSize(value)
Sets the value of blockSize.
New in version 3.0.0.
setColdStartStrategy(value)
Sets the value of coldStartStrategy.
New in version 3.0.0.
setItemCol(value)
Sets the value of itemCol.
New in version 3.0.0.
setPredictionCol(value)
Sets the value of predictionCol.
New in version 3.0.0.
setUserCol(value)
Sets the value of userCol.
New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
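A minimal hedged sketch of the common evaluation pattern around transform, reusing the model and test names from the ALS class example above:

>>> model.setColdStartStrategy("drop")  # drop rows with NaN predictions for unseen ids
ALSModel...
>>> predictions = model.transform(test)  # safe to pass to an evaluator afterwards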
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
blockSize = Param(parent='undefined', name='blockSize', doc='block size for stacking input data in matrices. Data is stacked within partitions. If block size is more than remaining data in a partition then it is adjusted to the size of this data.')
coldStartStrategy = Param(parent='undefined', name='coldStartStrategy', doc="strategy for dealing with unknown or new users/items at prediction time. This may be useful in cross-validation or production scenarios, for handling user/item ids the model has not seen in the training data. Supported values: 'nan', 'drop'.")
itemCol = Param(parent='undefined', name='itemCol', doc='column name for item ids. Ids must be within the integer value range.')
itemFactors
a DataFrame that stores item factors in two columns: id and features
New in version 1.4.0.
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
rank
rank of the matrix factorization model
New in version 1.4.0.
userCol = Param(parent='undefined', name='userCol', doc='column name for user ids. Ids must be within the integer value range.')
userFactors
a DataFrame that stores user factors in two columns: id and features
New in version 1.4.0.
3.3.9 Regression
AFTSurvivalRegression(*args, **kwargs) - Accelerated Failure Time (AFT) Model Survival Regression
AFTSurvivalRegressionModel([java_model]) - Model fitted by AFTSurvivalRegression.
DecisionTreeRegressor(*args, **kwargs) - Decision tree learning algorithm for regression. It supports both continuous and categorical features.
DecisionTreeRegressionModel([java_model]) - Model fitted by DecisionTreeRegressor.
GBTRegressor(*args, **kwargs) - Gradient-Boosted Trees (GBTs) learning algorithm for regression. It supports both continuous and categorical features.
GBTRegressionModel([java_model]) - Model fitted by GBTRegressor.
GeneralizedLinearRegression(*args, **kwargs) - Generalized Linear Regression.
GeneralizedLinearRegressionModel([java_model]) - Model fitted by GeneralizedLinearRegression.
GeneralizedLinearRegressionSummary([java_obj]) - Generalized linear regression results evaluated on a dataset.
GeneralizedLinearRegressionTrainingSummary([...]) - Generalized linear regression training results.
IsotonicRegression(*args, **kwargs) - Currently implemented using parallelized pool adjacent violators algorithm.
IsotonicRegressionModel([java_model]) - Model fitted by IsotonicRegression.
LinearRegression(*args, **kwargs) - Linear regression.
LinearRegressionModel([java_model]) - Model fitted by LinearRegression.
LinearRegressionSummary([java_obj]) - Linear regression results evaluated on a dataset.
LinearRegressionTrainingSummary([java_obj]) - Linear regression training results.
RandomForestRegressor(*args, **kwargs) - Random Forest learning algorithm for regression. It supports both continuous and categorical features.
RandomForestRegressionModel([java_model]) - Model fitted by RandomForestRegressor.
FMRegressor(*args, **kwargs) - Factorization Machines learning algorithm for regression.
FMRegressionModel([java_model]) - Model fitted by FMRegressor.
AFTSurvivalRegression
class pyspark.ml.regression.AFTSurvivalRegression(*args, **kwargs)
Accelerated Failure Time (AFT) Model Survival Regression
Fit a parametric AFT survival regression model based on the Weibull distribution of the survival time.
Notes
For more information, see the Wikipedia page on the AFT model.
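For orientation, the parametric form typically assumed by AFT regression (a standard textbook formulation, not quoted from this reference) models the log of the survival time linearly:

\log T = \beta^{\top} x + \sigma \varepsilon

where x are the features, \beta the coefficients, \sigma the scale, and \varepsilon follows the extreme-value distribution corresponding to a Weibull-distributed T.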
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0), 1.0),
...     (1e-40, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
>>> aftsr = AFTSurvivalRegression()
>>> aftsr.setMaxIter(10)
AFTSurvivalRegression...
>>> aftsr.getMaxIter()
10
>>> aftsr.clear(aftsr.maxIter)
>>> model = aftsr.fit(df)
>>> model.getMaxBlockSizeInMB()
0.0
>>> model.setFeaturesCol("features")
AFTSurvivalRegressionModel...
>>> model.predict(Vectors.dense(6.3))
1.0
>>> model.predictQuantiles(Vectors.dense(6.3))
DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
>>> model.transform(df).show()
+-------+---------+------+----------+
|  label| features|censor|prediction|
+-------+---------+------+----------+
|    1.0|    [1.0]|   1.0|       1.0|
|1.0E-40|(1,[],[])|   0.0|       1.0|
+-------+---------+------+----------+
...
>>> aftsr_path = temp_path + "/aftsr"
>>> aftsr.save(aftsr_path)
>>> aftsr2 = AFTSurvivalRegression.load(aftsr_path)
>>> aftsr2.getMaxIter()
100
>>> model_path = temp_path + "/aftsr_model"
>>> model.save(model_path)
>>> model2 = AFTSurvivalRegressionModel.load(model_path)
>>> model.coefficients == model2.coefficients
True
>>> model.intercept == model2.intercept
True
>>> model.scale == model2.scale
True
>>> model.transform(df).take(1) == model2.transform(df).take(1)
True
New in version 1.6.0.
Methods
clear(param) - Clears a param from the param map if it has been explicitly set.
copy([extra]) - Creates a copy of this instance with the same uid and some extra params.
explainParam(param) - Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() - Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) - Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) - Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) - Fits a model to the input dataset for each param map in paramMaps.
getAggregationDepth() - Gets the value of aggregationDepth or its default value.
getCensorCol() - Gets the value of censorCol or its default value.
getFeaturesCol() - Gets the value of featuresCol or its default value.
getFitIntercept() - Gets the value of fitIntercept or its default value.
getLabelCol() - Gets the value of labelCol or its default value.
getMaxBlockSizeInMB() - Gets the value of maxBlockSizeInMB or its default value.
getMaxIter() - Gets the value of maxIter or its default value.
getOrDefault(param) - Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) - Gets a param by its name.
getPredictionCol() - Gets the value of predictionCol or its default value.
getQuantileProbabilities() - Gets the value of quantileProbabilities or its default value.
getQuantilesCol() - Gets the value of quantilesCol or its default value.
getTol() - Gets the value of tol or its default value.
hasDefault(param) - Checks whether a param has a default value.
hasParam(paramName) - Tests whether this instance contains a param with a given (string) name.
isDefined(param) - Checks whether a param is explicitly set by user or has a default value.
isSet(param) - Checks whether a param is explicitly set by user.
load(path) - Reads an ML instance from the input path, a shortcut of read().load(path).
read() - Returns an MLReader instance for this class.
save(path) - Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) - Sets a parameter in the embedded param map.
setAggregationDepth(value) - Sets the value of aggregationDepth.
setCensorCol(value) - Sets the value of censorCol.
setFeaturesCol(value) - Sets the value of featuresCol.
setFitIntercept(value) - Sets the value of fitIntercept.
setLabelCol(value) - Sets the value of labelCol.
setMaxBlockSizeInMB(value) - Sets the value of maxBlockSizeInMB.
setMaxIter(value) - Sets the value of maxIter.
setParams(*args, **kwargs) - setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], quantilesCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0)
setPredictionCol(value) - Sets the value of predictionCol.
setQuantileProbabilities(value) - Sets the value of quantileProbabilities.
setQuantilesCol(value) - Sets the value of quantilesCol.
setTol(value) - Sets the value of tol.
write() - Returns an MLWriter instance for this ML instance.
Attributes
aggregationDepth
censorCol
featuresCol
fitIntercept
labelCol
maxBlockSizeInMB
maxIter
params - Returns all params ordered by name.
predictionCol
quantileProbabilities
quantilesCol
tol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getAggregationDepth()
Gets the value of aggregationDepth or its default value.

getCensorCol()
Gets the value of censorCol or its default value.
New in version 1.6.0.
getFeaturesCol()
Gets the value of featuresCol or its default value.

getFitIntercept()
Gets the value of fitIntercept or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getMaxBlockSizeInMB()
Gets the value of maxBlockSizeInMB or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getQuantileProbabilities()
Gets the value of quantileProbabilities or its default value.
New in version 1.6.0.
getQuantilesCol()
Gets the value of quantilesCol or its default value.
New in version 1.6.0.
getTol()
Gets the value of tol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.
classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setAggregationDepth(value)
Sets the value of aggregationDepth.
New in version 2.1.0.
setCensorCol(value)
Sets the value of censorCol.
New in version 1.6.0.
setFeaturesCol(value)
Sets the value of featuresCol.
New in version 3.0.0.
setFitIntercept(value)
Sets the value of fitIntercept.
New in version 1.6.0.
setLabelCol(value)
Sets the value of labelCol.
New in version 3.0.0.
setMaxBlockSizeInMB(value)
Sets the value of maxBlockSizeInMB.
New in version 3.1.0.
setMaxIter(value)
Sets the value of maxIter.
New in version 1.6.0.
setParams(*args, **kwargs)
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], quantilesCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0)
New in version 1.6.0.
setPredictionCol(value)
Sets the value of predictionCol.
New in version 3.0.0.
setQuantileProbabilities(value)
Sets the value of quantileProbabilities.
New in version 1.6.0.
setQuantilesCol(value)
Sets the value of quantilesCol.
New in version 1.6.0.
setTol(value)
Sets the value of tol.
New in version 1.6.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')
censorCol = Param(parent='undefined', name='censorCol', doc='censor column name. The value of this column could be 0 or 1. If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
quantileProbabilities = Param(parent='undefined', name='quantileProbabilities', doc='quantile probabilities array. Values of the quantile probabilities array should be in the range (0, 1) and the array should be non-empty.')
quantilesCol = Param(parent='undefined', name='quantilesCol', doc='quantiles column name. This column will output quantiles of corresponding quantileProbabilities if it is set.')
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
AFTSurvivalRegressionModel
class pyspark.ml.regression.AFTSurvivalRegressionModel(java_model=None)
    Model fitted by AFTSurvivalRegression.
New in version 1.6.0.
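The source page carries no usage example for this model class, so here is a minimal sketch, assuming an active SparkSession named spark, the default censorCol ("censor"), and illustrative data:

# Sketch: fit an AFTSurvivalRegression and use the fitted model.
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import AFTSurvivalRegression

df = spark.createDataFrame([
    (1.0, Vectors.dense(1.0), 1.0),          # uncensored row (censor == 1)
    (1e-40, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
aft = AFTSurvivalRegression(quantileProbabilities=[0.5], quantilesCol="quantiles")
model = aft.fit(df)                           # AFTSurvivalRegressionModel
model.predict(Vectors.dense(6.3))             # point prediction for one vector
model.predictQuantiles(Vectors.dense(6.3))    # quantiles at the probabilities above
model.transform(df).select("prediction", "quantiles").show()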
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getAggregationDepth(): Gets the value of aggregationDepth or its default value.
getCensorCol(): Gets the value of censorCol or its default value.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getFitIntercept(): Gets the value of fitIntercept or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getMaxBlockSizeInMB(): Gets the value of maxBlockSizeInMB or its default value.
getMaxIter(): Gets the value of maxIter or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getQuantileProbabilities(): Gets the value of quantileProbabilities or its default value.
getQuantilesCol(): Gets the value of quantilesCol or its default value.
getTol(): Gets the value of tol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value): Predict label for the given features.
predictQuantiles(features): Predicted Quantiles
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setFeaturesCol(value): Sets the value of featuresCol.
setPredictionCol(value): Sets the value of predictionCol.
setQuantileProbabilities(value): Sets the value of quantileProbabilities.
setQuantilesCol(value): Sets the value of quantilesCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
aggregationDepth
censorCol
coefficients: Model coefficients.
featuresCol
fitIntercept
intercept: Model intercept.
labelCol
maxBlockSizeInMB
maxIter
numFeatures: Returns the number of features the model was trained on.
params: Returns all params ordered by name.
predictionCol
quantileProbabilities
quantilesCol
scale: Model scale parameter.
tol
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
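To make the merge ordering concrete, here is a hedged illustration (not from the original page); model stands for any fitted Params instance such as this one, and "p2" is an arbitrary column name:

# Sketch: the extra map wins over both defaults and user-set values.
pm = model.extractParamMap()                              # defaults merged with user-set values
pm2 = model.extractParamMap({model.predictionCol: "p2"})  # "p2" overrides "prediction"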
getAggregationDepth()
    Gets the value of aggregationDepth or its default value.

getCensorCol()
    Gets the value of censorCol or its default value.

New in version 1.6.0.

getFeaturesCol()
    Gets the value of featuresCol or its default value.

getFitIntercept()
    Gets the value of fitIntercept or its default value.

getLabelCol()
    Gets the value of labelCol or its default value.

getMaxBlockSizeInMB()
    Gets the value of maxBlockSizeInMB or its default value.

getMaxIter()
    Gets the value of maxIter or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getPredictionCol()
    Gets the value of predictionCol or its default value.

getQuantileProbabilities()
    Gets the value of quantileProbabilities or its default value.

New in version 1.6.0.

getQuantilesCol()
    Gets the value of quantilesCol or its default value.

New in version 1.6.0.

getTol()
    Gets the value of tol or its default value.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
    Predict label for the given features.

New in version 3.0.0.
predictQuantiles(features)
    Predicted Quantiles
New in version 2.0.0.
classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setFeaturesCol(value)
    Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)
    Sets the value of predictionCol.

New in version 3.0.0.

setQuantileProbabilities(value)
    Sets the value of quantileProbabilities.

New in version 3.0.0.

setQuantilesCol(value)
    Sets the value of quantilesCol.

New in version 3.0.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
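As a hedged aside (not from the original page): the optional params map lets a single call use different param values without mutating the model. Keys are Param attributes; the model and DataFrame names below are assumptions.

# Sketch: one-off override of the prediction column for a single call.
out = model.transform(df)                                   # uses the configured predictionCol
out2 = model.transform(df, {model.predictionCol: "pred2"})  # "pred2" applies to this call only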
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')
censorCol = Param(parent='undefined', name='censorCol', doc='censor column name. The value of this column could be 0 or 1. If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored.')
coefficients
    Model coefficients.
New in version 2.0.0.
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
intercept
    Model intercept.
New in version 1.6.0.
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
numFeatures
    Returns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
quantileProbabilities = Param(parent='undefined', name='quantileProbabilities', doc='quantile probabilities array. Values of the quantile probabilities array should be in the range (0, 1) and the array should be non-empty.')
quantilesCol = Param(parent='undefined', name='quantilesCol', doc='quantiles column name. This column will output quantiles of corresponding quantileProbabilities if it is set.')
scale
    Model scale parameter.
New in version 1.6.0.
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
DecisionTreeRegressor
class pyspark.ml.regression.DecisionTreeRegressor(*args, **kwargs)
    Decision tree learning algorithm for regression. It supports both continuous and categorical features.
New in version 1.4.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> dt = DecisionTreeRegressor(maxDepth=2)
>>> dt.setVarianceCol("variance")
DecisionTreeRegressor...
>>> model = dt.fit(df)
>>> model.getVarianceCol()
'variance'
>>> model.setLeafCol("leafId")
DecisionTreeRegressionModel...
>>> model.depth
1
>>> model.numNodes
3
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> model.predictLeaf(test0.head().features)
0.0
>>> result.leafId
0.0
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> dtr_path = temp_path + "/dtr"
>>> dt.save(dtr_path)
>>> dt2 = DecisionTreeRegressor.load(dtr_path)
>>> dt2.getMaxDepth()
2
>>> model_path = temp_path + "/dtr_model"
>>> model.save(model_path)
>>> model2 = DecisionTreeRegressionModel.load(model_path)
>>> model.numNodes == model2.numNodes
True
>>> model.depth == model2.depth
True
>>> model.transform(test1).head().variance
0.0
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> df3 = spark.createDataFrame([
...     (1.0, 0.2, Vectors.dense(1.0)),
...     (1.0, 0.8, Vectors.dense(1.0)),
...     (0.0, 1.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
>>> dt3 = DecisionTreeRegressor(maxDepth=2, weightCol="weight", varianceCol="variance")
>>> model3 = dt3.fit(df3)
>>> print(model3.toDebugString)
DecisionTreeRegressionModel...depth=1, numNodes=3...
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getCacheNodeIds(): Gets the value of cacheNodeIds or its default value.
getCheckpointInterval(): Gets the value of checkpointInterval or its default value.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getImpurity(): Gets the value of impurity or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getLeafCol(): Gets the value of leafCol or its default value.
getMaxBins(): Gets the value of maxBins or its default value.
getMaxDepth(): Gets the value of maxDepth or its default value.
getMaxMemoryInMB(): Gets the value of maxMemoryInMB or its default value.
getMinInfoGain(): Gets the value of minInfoGain or its default value.
getMinInstancesPerNode(): Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode(): Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getSeed(): Gets the value of seed or its default value.
getVarianceCol(): Gets the value of varianceCol or its default value.
getWeightCol(): Gets the value of weightCol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setCacheNodeIds(value): Sets the value of cacheNodeIds.
setCheckpointInterval(value): Sets the value of checkpointInterval.
setFeaturesCol(value): Sets the value of featuresCol.
setImpurity(value): Sets the value of impurity.
setLabelCol(value): Sets the value of labelCol.
setLeafCol(value): Sets the value of leafCol.
setMaxBins(value): Sets the value of maxBins.
setMaxDepth(value): Sets the value of maxDepth.
setMaxMemoryInMB(value): Sets the value of maxMemoryInMB.
setMinInfoGain(value): Sets the value of minInfoGain.
setMinInstancesPerNode(value): Sets the value of minInstancesPerNode.
setMinWeightFractionPerNode(value): Sets the value of minWeightFractionPerNode.
setParams(self, *[, featuresCol, labelCol, ...]): Sets params for the DecisionTreeRegressor.
setPredictionCol(value): Sets the value of predictionCol.
setSeed(value): Sets the value of seed.
setVarianceCol(value): Sets the value of varianceCol.
setWeightCol(value): Sets the value of weightCol.
write(): Returns an MLWriter instance for this ML instance.
Attributes
cacheNodeIds
checkpointInterval
featuresCol
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
params: Returns all params ordered by name.
predictionCol
seed
supportedImpurities
varianceCol
weightCol
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
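A brief sketch of consuming the (index, model) pairs described above; the DataFrame df and the depth values are assumptions for illustration:

# Sketch: fit one model per param map, collecting results in paramMaps order.
dt = DecisionTreeRegressor()
paramMaps = [{dt.maxDepth: d} for d in (2, 5, 10)]
models = [None] * len(paramMaps)
for index, model in dt.fitMultiple(df, paramMaps):
    models[index] = model  # indices may arrive out of order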
getCacheNodeIds()
    Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
    Gets the value of checkpointInterval or its default value.

getFeaturesCol()
    Gets the value of featuresCol or its default value.

getImpurity()
    Gets the value of impurity or its default value.
New in version 1.4.0.
getLabelCol()
    Gets the value of labelCol or its default value.

getLeafCol()
    Gets the value of leafCol or its default value.

getMaxBins()
    Gets the value of maxBins or its default value.

getMaxDepth()
    Gets the value of maxDepth or its default value.

getMaxMemoryInMB()
    Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
    Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
    Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
    Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getPredictionCol()
    Gets the value of predictionCol or its default value.

getSeed()
    Gets the value of seed or its default value.

getVarianceCol()
    Gets the value of varianceCol or its default value.

getWeightCol()
    Gets the value of weightCol or its default value.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.
save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setCacheNodeIds(value)
    Sets the value of cacheNodeIds.

New in version 1.4.0.

setCheckpointInterval(value)
    Sets the value of checkpointInterval.

New in version 1.4.0.

setFeaturesCol(value)
    Sets the value of featuresCol.

New in version 3.0.0.

setImpurity(value)
    Sets the value of impurity.

New in version 1.4.0.

setLabelCol(value)
    Sets the value of labelCol.

New in version 3.0.0.

setLeafCol(value)
    Sets the value of leafCol.

setMaxBins(value)
    Sets the value of maxBins.

New in version 1.4.0.

setMaxDepth(value)
    Sets the value of maxDepth.

New in version 1.4.0.

setMaxMemoryInMB(value)
    Sets the value of maxMemoryInMB.

New in version 1.4.0.

setMinInfoGain(value)
    Sets the value of minInfoGain.

New in version 1.4.0.

setMinInstancesPerNode(value)
    Sets the value of minInstancesPerNode.

New in version 1.4.0.

setMinWeightFractionPerNode(value)
    Sets the value of minWeightFractionPerNode.

New in version 3.0.0.
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", seed=None, varianceCol=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
Sets params for the DecisionTreeRegressor.
New in version 1.4.0.
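For illustration only, a minimal sketch of how setParams mirrors the constructor keywords (the values chosen here are arbitrary):

# Sketch: setParams accepts the same keywords as the constructor and returns
# the estimator, so the two lines below configure identical regressors.
dt = DecisionTreeRegressor()
dt.setParams(maxDepth=3, minInstancesPerNode=2, seed=42)
dt2 = DecisionTreeRegressor(maxDepth=3, minInstancesPerNode=2, seed=42)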
setPredictionCol(value)
    Sets the value of predictionCol.

New in version 3.0.0.

setSeed(value)
    Sets the value of seed.

setVarianceCol(value)
    Sets the value of varianceCol.

New in version 2.0.0.

setWeightCol(value)
    Sets the value of weightCol.

New in version 3.0.0.

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')
maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')
maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')
maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')
minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')
minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')
minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
supportedImpurities = ['variance']
varianceCol = Param(parent='undefined', name='varianceCol', doc='column name for the biased sample variance of prediction.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
DecisionTreeRegressionModel
class pyspark.ml.regression.DecisionTreeRegressionModel(java_model=None)
    Model fitted by DecisionTreeRegressor.
New in version 1.4.0.
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCacheNodeIds(): Gets the value of cacheNodeIds or its default value.
getCheckpointInterval(): Gets the value of checkpointInterval or its default value.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getImpurity(): Gets the value of impurity or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getLeafCol(): Gets the value of leafCol or its default value.
getMaxBins(): Gets the value of maxBins or its default value.
getMaxDepth(): Gets the value of maxDepth or its default value.
getMaxMemoryInMB(): Gets the value of maxMemoryInMB or its default value.
getMinInfoGain(): Gets the value of minInfoGain or its default value.
getMinInstancesPerNode(): Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode(): Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getSeed(): Gets the value of seed or its default value.
getVarianceCol(): Gets the value of varianceCol or its default value.
getWeightCol(): Gets the value of weightCol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value): Predict label for the given features.
predictLeaf(value): Predict the indices of the leaves corresponding to the feature vector.
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setFeaturesCol(value): Sets the value of featuresCol.
setLeafCol(value): Sets the value of leafCol.
setPredictionCol(value): Sets the value of predictionCol.
setVarianceCol(value): Sets the value of varianceCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
cacheNodeIds
checkpointInterval
depth: Return depth of the decision tree.
featureImportances: Estimate of the importance of each feature.
featuresCol
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numFeatures: Returns the number of features the model was trained on.
numNodes: Return number of nodes of the decision tree.
params: Returns all params ordered by name.
predictionCol
seed
supportedImpurities
toDebugString: Full description of model.
varianceCol
weightCol
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getCacheNodeIds()
    Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
    Gets the value of checkpointInterval or its default value.

getFeaturesCol()
    Gets the value of featuresCol or its default value.

getImpurity()
    Gets the value of impurity or its default value.

New in version 1.4.0.

getLabelCol()
    Gets the value of labelCol or its default value.

getLeafCol()
    Gets the value of leafCol or its default value.

getMaxBins()
    Gets the value of maxBins or its default value.

getMaxDepth()
    Gets the value of maxDepth or its default value.
getMaxMemoryInMB()
    Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
    Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
    Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
    Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getPredictionCol()
    Gets the value of predictionCol or its default value.

getSeed()
    Gets the value of seed or its default value.

getVarianceCol()
    Gets the value of varianceCol or its default value.

getWeightCol()
    Gets the value of weightCol or its default value.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)
    Predict label for the given features.

New in version 3.0.0.

predictLeaf(value)
    Predict the indices of the leaves corresponding to the feature vector.

New in version 3.0.0.

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.
setFeaturesCol(value)
    Sets the value of featuresCol.

New in version 3.0.0.

setLeafCol(value)
    Sets the value of leafCol.

setPredictionCol(value)
    Sets the value of predictionCol.

New in version 3.0.0.

setVarianceCol(value)
    Sets the value of varianceCol.

New in version 3.0.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
depth
    Return depth of the decision tree.
New in version 1.5.0.
featureImportances
    Estimate of the importance of each feature.

    This generalizes the idea of "Gini" importance to other losses, following the explanation of Gini importance from "Random Forests" documentation by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.

    This feature importance is calculated as follows:

    • importance(feature j) = sum (over nodes which split on feature j) of the gain, where gain is scaled by the number of instances passing through node
    • Normalize importances for tree to sum to 1.
New in version 2.0.0.
Notes
Feature importance for single decision trees can have high variance due to correlated predictor variables. Consider using a RandomForestRegressor to determine feature importance instead.
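As a hedged sketch of the comparison this note suggests, assuming a "label"/"features" DataFrame df:

# Sketch: contrast single-tree importances with a forest's averaged ones.
from pyspark.ml.regression import DecisionTreeRegressor, RandomForestRegressor
tree_imp = DecisionTreeRegressor(maxDepth=5).fit(df).featureImportances
forest_imp = RandomForestRegressor(numTrees=50, seed=1).fit(df).featureImportances
# Averaging over many trees damps the variance that correlated predictors
# induce in any single tree's importance estimates.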
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')
maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')
maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')
maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')
minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')
minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')
minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')
numFeatures
    Returns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
numNodes
    Return number of nodes of the decision tree.
New in version 1.5.0.
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
supportedImpurities = ['variance']
toDebugString
    Full description of model.
New in version 2.0.0.
varianceCol = Param(parent='undefined', name='varianceCol', doc='column name for the biased sample variance of prediction.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
GBTRegressor
class pyspark.ml.regression.GBTRegressor(*args, **kwargs)
    Gradient-Boosted Trees (GBTs) learning algorithm for regression. It supports both continuous and categorical features.
New in version 1.4.0.
Examples
>>> from numpy import allclose
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> gbt = GBTRegressor(maxDepth=2, seed=42, leafCol="leafId")
>>> gbt.setMaxIter(5)
GBTRegressor...
>>> gbt.setMinWeightFractionPerNode(0.049)
GBTRegressor...
>>> gbt.getMaxIter()
5
>>> print(gbt.getImpurity())
variance
>>> print(gbt.getFeatureSubsetStrategy())
all
>>> model = gbt.fit(df)
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictLeaf(test0.head().features)
DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.leafId
DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> gbtr_path = temp_path + "gbtr"
>>> gbt.save(gbtr_path)
>>> gbt2 = GBTRegressor.load(gbtr_path)
>>> gbt2.getMaxDepth()
2
>>> model_path = temp_path + "gbtr_model"
>>> model.save(model_path)
>>> model2 = GBTRegressionModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.treeWeights == model2.treeWeights
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> model.trees
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
...     ["label", "features"])
>>> model.evaluateEachIteration(validation, "squared")
[0.0, 0.0, 0.0, 0.0, 0.0]
>>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
>>> gbt.getValidationIndicatorCol()
'validationIndicator'
>>> gbt.getValidationTol()
0.01
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getCacheNodeIds(): Gets the value of cacheNodeIds or its default value.
getCheckpointInterval(): Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy(): Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getImpurity(): Gets the value of impurity or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getLeafCol(): Gets the value of leafCol or its default value.
getLossType(): Gets the value of lossType or its default value.
getMaxBins(): Gets the value of maxBins or its default value.
getMaxDepth(): Gets the value of maxDepth or its default value.
getMaxIter(): Gets the value of maxIter or its default value.
getMaxMemoryInMB(): Gets the value of maxMemoryInMB or its default value.
getMinInfoGain(): Gets the value of minInfoGain or its default value.
getMinInstancesPerNode(): Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode(): Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getSeed(): Gets the value of seed or its default value.
getStepSize(): Gets the value of stepSize or its default value.
getSubsamplingRate(): Gets the value of subsamplingRate or its default value.
getValidationIndicatorCol(): Gets the value of validationIndicatorCol or its default value.
getValidationTol(): Gets the value of validationTol or its default value.
getWeightCol(): Gets the value of weightCol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setCacheNodeIds(value): Sets the value of cacheNodeIds.
setCheckpointInterval(value): Sets the value of checkpointInterval.
setFeatureSubsetStrategy(value): Sets the value of featureSubsetStrategy.
setFeaturesCol(value): Sets the value of featuresCol.
setImpurity(value): Sets the value of impurity.
setLabelCol(value): Sets the value of labelCol.
setLeafCol(value): Sets the value of leafCol.
setLossType(value): Sets the value of lossType.
setMaxBins(value): Sets the value of maxBins.
setMaxDepth(value): Sets the value of maxDepth.
setMaxIter(value): Sets the value of maxIter.
setMaxMemoryInMB(value): Sets the value of maxMemoryInMB.
setMinInfoGain(value): Sets the value of minInfoGain.
setMinInstancesPerNode(value): Sets the value of minInstancesPerNode.
setMinWeightFractionPerNode(value): Sets the value of minWeightFractionPerNode.
setParams(self, *[, featuresCol, labelCol, ...]): Sets params for Gradient Boosted Tree Regression.
setPredictionCol(value): Sets the value of predictionCol.
setSeed(value): Sets the value of seed.
setStepSize(value): Sets the value of stepSize.
setSubsamplingRate(value): Sets the value of subsamplingRate.
setValidationIndicatorCol(value): Sets the value of validationIndicatorCol.
setWeightCol(value): Sets the value of weightCol.
write(): Returns an MLWriter instance for this ML instance.
Attributes
cacheNodeIds
checkpointInterval
featureSubsetStrategy
featuresCol
impurity
labelCol
leafCol
lossType
maxBins
maxDepth
maxIter
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
params: Returns all params ordered by name.
predictionCol
seed
stepSize
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
supportedLossTypes
validationIndicatorCol
validationTol
weightCol
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getCacheNodeIds()
    Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
    Gets the value of checkpointInterval or its default value.

getFeatureSubsetStrategy()
    Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()
    Gets the value of featuresCol or its default value.

getImpurity()
    Gets the value of impurity or its default value.

New in version 1.4.0.
getLabelCol()
    Gets the value of labelCol or its default value.

getLeafCol()
    Gets the value of leafCol or its default value.

getLossType()
    Gets the value of lossType or its default value.

New in version 1.4.0.

getMaxBins()
    Gets the value of maxBins or its default value.

getMaxDepth()
    Gets the value of maxDepth or its default value.

getMaxIter()
    Gets the value of maxIter or its default value.

getMaxMemoryInMB()
    Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
    Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
    Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
    Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getPredictionCol()
    Gets the value of predictionCol or its default value.

getSeed()
    Gets the value of seed or its default value.

getStepSize()
    Gets the value of stepSize or its default value.

getSubsamplingRate()
    Gets the value of subsamplingRate or its default value.

New in version 1.4.0.

getValidationIndicatorCol()
    Gets the value of validationIndicatorCol or its default value.

getValidationTol()
    Gets the value of validationTol or its default value.

New in version 3.0.0.

getWeightCol()
    Gets the value of weightCol or its default value.
hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setCacheNodeIds(value)
    Sets the value of cacheNodeIds.

New in version 1.4.0.

setCheckpointInterval(value)
    Sets the value of checkpointInterval.

New in version 1.4.0.

setFeatureSubsetStrategy(value)
    Sets the value of featureSubsetStrategy.

New in version 2.4.0.

setFeaturesCol(value)
    Sets the value of featuresCol.

New in version 3.0.0.

setImpurity(value)
    Sets the value of impurity.

New in version 1.4.0.

setLabelCol(value)
    Sets the value of labelCol.

New in version 3.0.0.

setLeafCol(value)
    Sets the value of leafCol.

setLossType(value)
    Sets the value of lossType.

New in version 1.4.0.

setMaxBins(value)
    Sets the value of maxBins.

New in version 1.4.0.
setMaxDepth(value)
    Sets the value of maxDepth.

New in version 1.4.0.

setMaxIter(value)
    Sets the value of maxIter.

New in version 1.4.0.

setMaxMemoryInMB(value)
    Sets the value of maxMemoryInMB.

New in version 1.4.0.

setMinInfoGain(value)
    Sets the value of minInfoGain.

New in version 1.4.0.

setMinInstancesPerNode(value)
    Sets the value of minInstancesPerNode.

New in version 1.4.0.

setMinWeightFractionPerNode(value)
    Sets the value of minWeightFractionPerNode.

New in version 3.0.0.
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, impurity="variance", featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, weightCol=None)
Sets params for Gradient Boosted Tree Regression.
New in version 1.4.0.
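A hedged sketch of the validation-based early stopping that validationIndicatorCol and validationTol (documented below) enable; the column name and DataFrame are assumptions:

# Sketch: rows with validationIndicator == True are held out to monitor error;
# training stops before maxIter once improvement drops below validationTol.
gbt = GBTRegressor(maxIter=100)
gbt.setParams(validationIndicatorCol="validationIndicator", validationTol=0.01)
model = gbt.fit(df_with_indicator)  # DataFrame assumed to carry the boolean column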
setPredictionCol(value)
    Sets the value of predictionCol.

New in version 3.0.0.

setSeed(value)
    Sets the value of seed.

New in version 1.4.0.

setStepSize(value)
    Sets the value of stepSize.

New in version 1.4.0.

setSubsamplingRate(value)
    Sets the value of subsamplingRate.

New in version 1.4.0.

setValidationIndicatorCol(value)
    Sets the value of validationIndicatorCol.

New in version 3.0.0.
setWeightCol(value)
    Sets the value of weightCol.

New in version 3.0.0.

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')
lossType = Param(parent='undefined', name='lossType', doc='Loss function which GBT tries to minimize (case-insensitive). Supported options: squared, absolute')
maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')
maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')
minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')
minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')
minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
stepSize = Param(parent='undefined', name='stepSize', doc='Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.')
subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')
supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']
supportedImpurities = ['variance']
supportedLossTypes = ['squared', 'absolute']
validationIndicatorCol = Param(parent='undefined', name='validationIndicatorCol', doc='name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.')
validationTol = Param(parent='undefined', name='validationTol', doc='Threshold for stopping early when fit with validation is used. If the error rate on the validation input changes by less than the validationTol, then learning will stop early (before `maxIter`). This parameter is ignored when fit without validation is used.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
GBTRegressionModel
class pyspark.ml.regression.GBTRegressionModel(java_model=None)
    Model fitted by GBTRegressor.
New in version 1.4.0.
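The original page shows model usage only under the estimator; the following is a minimal sketch, with the path and validation DataFrame as assumptions:

# Sketch: reload a persisted GBTRegressionModel and score per-iteration loss.
from pyspark.ml.regression import GBTRegressionModel
model = GBTRegressionModel.load("/tmp/gbtr_model")              # hypothetical path
losses = model.evaluateEachIteration(validation_df, "squared")  # one loss per tree
best = losses.index(min(losses)) + 1                            # e.g. a truncation point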
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
evaluateEachIteration(dataset, loss): Method to compute error or loss for every iteration of gradient boosting.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCacheNodeIds(): Gets the value of cacheNodeIds or its default value.
getCheckpointInterval(): Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy(): Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getImpurity(): Gets the value of impurity or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getLeafCol(): Gets the value of leafCol or its default value.
getLossType(): Gets the value of lossType or its default value.
getMaxBins(): Gets the value of maxBins or its default value.
getMaxDepth(): Gets the value of maxDepth or its default value.
getMaxIter(): Gets the value of maxIter or its default value.
getMaxMemoryInMB(): Gets the value of maxMemoryInMB or its default value.
getMinInfoGain(): Gets the value of minInfoGain or its default value.
getMinInstancesPerNode(): Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode(): Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getSeed(): Gets the value of seed or its default value.
getStepSize(): Gets the value of stepSize or its default value.
getSubsamplingRate(): Gets the value of subsamplingRate or its default value.
getValidationIndicatorCol(): Gets the value of validationIndicatorCol or its default value.
getValidationTol(): Gets the value of validationTol or its default value.
getWeightCol(): Gets the value of weightCol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value): Predict label for the given features.
predictLeaf(value): Predict the indices of the leaves corresponding to the feature vector.
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setFeaturesCol(value): Sets the value of featuresCol.
setLeafCol(value): Sets the value of leafCol.
setPredictionCol(value): Sets the value of predictionCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.
Attributes
cacheNodeIds
checkpointInterval
featureImportances Estimate of the importance of each feature.
featureSubsetStrategy
featuresCol
getNumTrees Number of trees in ensemble.
impurity
labelCol
leafCol
lossType
maxBins
maxDepth
maxIter
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numFeatures Returns the number of features the model was trained on.
params Returns all params ordered by name.
predictionCol
seed
stepSize
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
supportedLossTypes
toDebugString Full description of model.
totalNumNodes Total number of nodes, summed over all trees in the ensemble.
treeWeights Return the weights for each tree.
trees Trees in this ensemble.
validationIndicatorCol
validationTol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
evaluateEachIteration(dataset, loss)
Method to compute error or loss for every iteration of gradient boosting.
New in version 2.4.0.
Parameters
dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on, where dataset is an instance of pyspark.sql.DataFrame
loss [str] The loss function used to compute error. Supported options: squared, absolute
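A minimal usage sketch (the names model and test_df are hypothetical; model is assumed to be a fitted GBTRegressionModel and test_df a DataFrame with "label" and "features" columns):

# errors[i] is the loss after i+1 boosting iterations; inspecting the curve
# is a common way to check whether extra iterations still help.
errors = model.evaluateEachIteration(test_df, "squared")
best_iter = min(range(len(errors)), key=lambda i: errors[i]) + 1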
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
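A minimal sketch of the merge ordering described above, assuming only that GBTRegressor is importable (extra beats user-supplied, which beats defaults):

from pyspark.ml.regression import GBTRegressor

gbt = GBTRegressor(maxDepth=3)               # user-supplied value overrides the default
pm = gbt.extractParamMap({gbt.maxDepth: 7})  # extra value overrides the user-supplied one
assert pm[gbt.maxDepth] == 7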
getCacheNodeIds()
Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getFeatureSubsetStrategy()
Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getImpurity()
Gets the value of impurity or its default value.

New in version 1.4.0.

getLabelCol()
Gets the value of labelCol or its default value.

getLeafCol()
Gets the value of leafCol or its default value.

getLossType()
Gets the value of lossType or its default value.

New in version 1.4.0.

getMaxBins()
Gets the value of maxBins or its default value.

getMaxDepth()
Gets the value of maxDepth or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getMaxMemoryInMB()
Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getStepSize()
Gets the value of stepSize or its default value.

getSubsamplingRate()
Gets the value of subsamplingRate or its default value.

New in version 1.4.0.

getValidationIndicatorCol()
Gets the value of validationIndicatorCol or its default value.

getValidationTol()
Gets the value of validationTol or its default value.

New in version 3.0.0.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
Predict label for the given features.

New in version 3.0.0.

predictLeaf(value)
Predict the indices of the leaves corresponding to the feature vector.
New in version 3.0.0.
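A minimal sketch of single-instance prediction (model is hypothetical and assumed to be a fitted GBTRegressionModel trained on 2-dimensional feature vectors; the vector below is illustrative):

from pyspark.ml.linalg import Vectors

v = Vectors.dense(0.0, 1.0)
pred = model.predict(v)        # scalar prediction for one feature vector
leaves = model.predictLeaf(v)  # one leaf index per tree in the ensemble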
classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.
New in version 3.0.0.
setLeafCol(value)
Sets the value of leafCol.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
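A minimal sketch of the optional param map (model and df are hypothetical; the override applies only to this call and does not mutate the model):

out = model.transform(df, {model.predictionCol: "gbt_prediction"})
out.select("gbt_prediction").show()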
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
featureImportances
Estimate of the importance of each feature.

Each feature’s importance is the average of its importance across all trees in the ensemble. The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. (Hastie, Tibshirani, Friedman. “The Elements of Statistical Learning, 2nd Edition.” 2001.) and follows the implementation from scikit-learn.
New in version 2.0.0.
Examples
DecisionTreeRegressionModel.featureImportances
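A minimal sketch of mapping importances back to input columns (model and the column names are hypothetical; the names are assumed to match the order used by the feature assembler):

names = ["age", "income", "clicks"]
importances = model.featureImportances  # a Vector normalized to sum to 1
ranked = sorted(zip(names, importances.toArray()), key=lambda t: -t[1])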
featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
getNumTrees
Number of trees in ensemble.
New in version 2.0.0.
impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')
lossType = Param(parent='undefined', name='lossType', doc='Loss function which GBT tries to minimize (case-insensitive). Supported options: squared, absolute')
maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')
maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')
minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')
minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')
minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')
numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
stepSize = Param(parent='undefined', name='stepSize', doc='Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.')
subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')
supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']
supportedImpurities = ['variance']
supportedLossTypes = ['squared', 'absolute']
toDebugString
Full description of model.

New in version 2.0.0.

totalNumNodes
Total number of nodes, summed over all trees in the ensemble.

New in version 2.0.0.

treeWeights
Return the weights for each tree.

New in version 1.5.0.

trees
Trees in this ensemble. Warning: These have null parent Estimators.
New in version 2.0.0.
validationIndicatorCol = Param(parent='undefined', name='validationIndicatorCol', doc='name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.')
validationTol = Param(parent='undefined', name='validationTol', doc='Threshold for stopping early when fit with validation is used. If the error rate on the validation input changes by less than the validationTol, then learning will stop early (before `maxIter`). This parameter is ignored when fit without validation is used.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
GeneralizedLinearRegression
class pyspark.ml.regression.GeneralizedLinearRegression(*args, **kwargs)
Generalized Linear Regression.

Fit a Generalized Linear Model specified by giving a symbolic description of the linear predictor (link function) and a description of the error distribution (family). It supports “gaussian”, “binomial”, “poisson”, “gamma” and “tweedie” as family. The valid link functions for each family are listed below. The first link function of each family is the default one.
• “gaussian” -> “identity”, “log”, “inverse”
• “binomial” -> “logit”, “probit”, “cloglog”
• “poisson” -> “log”, “identity”, “sqrt”
• “gamma” -> “inverse”, “identity”, “log”
• “tweedie” -> power link function specified through “linkPower”. The default link power in the tweedie family is 1 - variancePower.
New in version 2.0.0.
Notes
For more information, see the Wikipedia page on GLM.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(0.0, 0.0)),
...     (1.0, Vectors.dense(1.0, 2.0)),
...     (2.0, Vectors.dense(0.0, 0.0)),
...     (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
>>> glr.setRegParam(0.1)
GeneralizedLinearRegression...
>>> glr.getRegParam()
0.1
>>> glr.clear(glr.regParam)
>>> glr.setMaxIter(10)
GeneralizedLinearRegression...
>>> glr.getMaxIter()
10
>>> glr.clear(glr.maxIter)
>>> model = glr.fit(df)
>>> model.setFeaturesCol("features")
GeneralizedLinearRegressionModel...
>>> model.getMaxIter()
25
>>> model.getAggregationDepth()
2
>>> transformed = model.transform(df)
>>> abs(transformed.head().prediction - 1.5) < 0.001
True
>>> abs(transformed.head().p - 1.5) < 0.001
True
>>> model.coefficients
DenseVector([1.5..., -1.0...])
>>> model.numFeatures
2
>>> abs(model.intercept - 1.5) < 0.001
True
>>> glr_path = temp_path + "/glr"
>>> glr.save(glr_path)
>>> glr2 = GeneralizedLinearRegression.load(glr_path)
>>> glr.getFamily() == glr2.getFamily()
True
>>> model_path = temp_path + "/glr_model"
>>> model.save(model_path)
>>> model2 = GeneralizedLinearRegressionModel.load(model_path)
>>> model.intercept == model2.intercept
True
>>> model.coefficients[0] == model2.coefficients[0]
True
>>> model.transform(df).take(1) == model2.transform(df).take(1)
True
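A hedged sketch of the tweedie family, where the link is chosen through linkPower rather than the link param (the values below are illustrative, not defaults; df is the DataFrame from the example above):

glr_tweedie = GeneralizedLinearRegression(
    family="tweedie",
    variancePower=1.5,  # between Poisson (1.0) and gamma (2.0)
    linkPower=0.0,      # link power 0.0 corresponds to the log link
)
model_t = glr_tweedie.fit(df)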
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getAggregationDepth() Gets the value of aggregationDepth or its default value.
getFamily() Gets the value of family or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFitIntercept() Gets the value of fitIntercept or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getLink() Gets the value of link or its default value.
getLinkPower() Gets the value of linkPower or its default value.
getLinkPredictionCol() Gets the value of linkPredictionCol or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOffsetCol() Gets the value of offsetCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getRegParam() Gets the value of regParam or its default value.
getSolver() Gets the value of solver or its default value.
getTol() Gets the value of tol or its default value.
getVariancePower() Gets the value of variancePower or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setAggregationDepth(value) Sets the value of aggregationDepth.
setFamily(value) Sets the value of family.
setFeaturesCol(value) Sets the value of featuresCol.
setFitIntercept(value) Sets the value of fitIntercept.
setLabelCol(value) Sets the value of labelCol.
setLink(value) Sets the value of link.
setLinkPower(value) Sets the value of linkPower.
setLinkPredictionCol(value) Sets the value of linkPredictionCol.
setMaxIter(value) Sets the value of maxIter.
setOffsetCol(value) Sets the value of offsetCol.
setParams(self, *[, labelCol, featuresCol, ...]) Sets params for generalized linear regression.
setPredictionCol(value) Sets the value of predictionCol.
setRegParam(value) Sets the value of regParam.
setSolver(value) Sets the value of solver.
setTol(value) Sets the value of tol.
setVariancePower(value) Sets the value of variancePower.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.
Attributes
aggregationDepth
family
featuresCol
fitIntercept
labelCol
link
linkPower
linkPredictionCol
maxIter
offsetCol
params Returns all params ordered by name.
predictionCol
regParam
solver
tol
variancePower
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
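A minimal sketch (df is hypothetical; because the iterator may yield models out of order, collect them by index):

glr = GeneralizedLinearRegression()
maps = [{glr.regParam: 0.0}, {glr.regParam: 0.1}, {glr.regParam: 1.0}]
models = [None] * len(maps)
for index, model in glr.fitMultiple(df, maps):
    models[index] = model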
getAggregationDepth()
Gets the value of aggregationDepth or its default value.

getFamily()
Gets the value of family or its default value.

New in version 2.0.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFitIntercept()
Gets the value of fitIntercept or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getLink()
Gets the value of link or its default value.

New in version 2.0.0.

getLinkPower()
Gets the value of linkPower or its default value.

New in version 2.2.0.

getLinkPredictionCol()
Gets the value of linkPredictionCol or its default value.

New in version 2.0.0.
getMaxIter()
Gets the value of maxIter or its default value.

getOffsetCol()
Gets the value of offsetCol or its default value.

New in version 2.3.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getRegParam()
Gets the value of regParam or its default value.

getSolver()
Gets the value of solver or its default value.

getTol()
Gets the value of tol or its default value.

getVariancePower()
Gets the value of variancePower or its default value.

New in version 2.2.0.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setAggregationDepth(value)
Sets the value of aggregationDepth.

New in version 3.0.0.
setFamily(value)
Sets the value of family.

New in version 2.0.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setFitIntercept(value)
Sets the value of fitIntercept.

New in version 2.0.0.

setLabelCol(value)
Sets the value of labelCol.

New in version 3.0.0.

setLink(value)
Sets the value of link.

New in version 2.0.0.

setLinkPower(value)
Sets the value of linkPower.

New in version 2.2.0.

setLinkPredictionCol(value)
Sets the value of linkPredictionCol.

New in version 2.0.0.

setMaxIter(value)
Sets the value of maxIter.

New in version 2.0.0.

setOffsetCol(value)
Sets the value of offsetCol.

New in version 2.3.0.
setParams(self, *, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2)
Sets params for generalized linear regression.
New in version 2.0.0.
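A minimal sketch (the values are illustrative; setParams mirrors the constructor keywords and replaces any previously set values):

glr = GeneralizedLinearRegression()
glr.setParams(family="poisson", link="log", maxIter=50, regParam=0.01)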
setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setRegParam(value)
Sets the value of regParam.

New in version 2.0.0.

setSolver(value)
Sets the value of solver.

New in version 2.0.0.
setTol(value)
Sets the value of tol.

New in version 2.0.0.

setVariancePower(value)
Sets the value of variancePower.

New in version 2.2.0.

setWeightCol(value)
Sets the value of weightCol.

New in version 2.0.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')
family = Param(parent='undefined', name='family', doc='The name of family which is a description of the error distribution to be used in the model. Supported options: gaussian (default), binomial, poisson, gamma and tweedie.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
link = Param(parent='undefined', name='link', doc='The name of link function which provides the relationship between the linear predictor and the mean of the distribution function. Supported options: identity, log, inverse, logit, probit, cloglog and sqrt.')
linkPower = Param(parent='undefined', name='linkPower', doc='The index in the power link function. Only applicable to the Tweedie family.')
linkPredictionCol = Param(parent='undefined', name='linkPredictionCol', doc='link prediction (linear predictor) column name')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
offsetCol = Param(parent='undefined', name='offsetCol', doc='The offset column name. If this is not set or empty, we treat all instance offsets as 0.0')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')
solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: irls.')
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
variancePower = Param(parent='undefined', name='variancePower', doc='The power in the variance function of the Tweedie distribution which characterizes the relationship between the variance and mean of the distribution. Only applicable for the Tweedie family. Supported values: 0 and [1, Inf).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
GeneralizedLinearRegressionModel
class pyspark.ml.regression.GeneralizedLinearRegressionModel(java_model=None)
Model fitted by GeneralizedLinearRegression.
New in version 2.0.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset) Evaluates the model on a test dataset.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getAggregationDepth() Gets the value of aggregationDepth or its default value.
getFamily() Gets the value of family or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFitIntercept() Gets the value of fitIntercept or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getLink() Gets the value of link or its default value.
getLinkPower() Gets the value of linkPower or its default value.
getLinkPredictionCol() Gets the value of linkPredictionCol or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOffsetCol() Gets the value of offsetCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getRegParam() Gets the value of regParam or its default value.
getSolver() Gets the value of solver or its default value.
getTol() Gets the value of tol or its default value.
getVariancePower() Gets the value of variancePower or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) Predict label for the given features.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setLinkPredictionCol(value) Sets the value of linkPredictionCol.
setPredictionCol(value) Sets the value of predictionCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
aggregationDepth
coefficients Model coefficients.
family
featuresCol
fitIntercept
hasSummary Indicates whether a training summary exists for this model instance.
intercept Model intercept.
labelCol
link
linkPower
linkPredictionCol
maxIter
numFeatures Returns the number of features the model was trained on.
offsetCol
params Returns all params ordered by name.
predictionCol
regParam
solver
summary Gets summary (e.g. residuals, deviance, pValues) of model on training set.
tol
variancePower
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
evaluate(dataset)
Evaluates the model on a test dataset.
New in version 2.0.0.
Parameters
dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on, where dataset is an instance of pyspark.sql.DataFrame
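A minimal sketch (model and test_df are hypothetical; evaluate returns a GeneralizedLinearRegressionSummary for the supplied data):

test_summary = model.evaluate(test_df)
print(test_summary.deviance, test_summary.aic)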
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getAggregationDepth()
Gets the value of aggregationDepth or its default value.

getFamily()
Gets the value of family or its default value.

New in version 2.0.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFitIntercept()
Gets the value of fitIntercept or its default value.

getLabelCol()
Gets the value of labelCol or its default value.
getLink()
Gets the value of link or its default value.

New in version 2.0.0.

getLinkPower()
Gets the value of linkPower or its default value.

New in version 2.2.0.

getLinkPredictionCol()
Gets the value of linkPredictionCol or its default value.

New in version 2.0.0.

getMaxIter()
Gets the value of maxIter or its default value.

getOffsetCol()
Gets the value of offsetCol or its default value.

New in version 2.3.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getRegParam()
Gets the value of regParam or its default value.

getSolver()
Gets the value of solver or its default value.

getTol()
Gets the value of tol or its default value.

getVariancePower()
Gets the value of variancePower or its default value.

New in version 2.2.0.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)
Predict label for the given features.
New in version 3.0.0.
classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.
New in version 3.0.0.
setLinkPredictionCol(value)
Sets the value of linkPredictionCol.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.
New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')
coefficients
Model coefficients.
New in version 2.0.0.
family = Param(parent='undefined', name='family', doc='The name of family which is a description of the error distribution to be used in the model. Supported options: gaussian (default), binomial, poisson, gamma and tweedie.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
hasSummary
Indicates whether a training summary exists for this model instance.
New in version 2.1.0.
intercept
Model intercept.
New in version 2.0.0.
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
link = Param(parent='undefined', name='link', doc='The name of link function which provides the relationship between the linear predictor and the mean of the distribution function. Supported options: identity, log, inverse, logit, probit, cloglog and sqrt.')
linkPower = Param(parent='undefined', name='linkPower', doc='The index in the power link function. Only applicable to the Tweedie family.')
linkPredictionCol = Param(parent='undefined', name='linkPredictionCol', doc='link prediction (linear predictor) column name')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
offsetCol = Param(parent='undefined', name='offsetCol', doc='The offset column name. If this is not set or empty, we treat all instance offsets as 0.0')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')
solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: irls.')
summary
Gets summary (e.g. residuals, deviance, pValues) of model on training set. An exception is thrown if trainingSummary is None.
New in version 2.0.0.
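A minimal sketch of inspecting the training summary (model is hypothetical; guard with hasSummary, because a model loaded from disk has no training summary and accessing .summary would raise):

if model.hasSummary:
    s = model.summary
    print(s.numIterations, s.aic)
    print(s.pValues, s.coefficientStandardErrors)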
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
variancePower = Param(parent='undefined', name='variancePower', doc='The power in the variance function of the Tweedie distribution which characterizes the relationship between the variance and mean of the distribution. Only applicable for the Tweedie family. Supported values: 0 and [1, Inf).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
GeneralizedLinearRegressionSummary
class pyspark.ml.regression.GeneralizedLinearRegressionSummary(java_obj=None)
Generalized linear regression results evaluated on a dataset.
New in version 2.0.0.
Methods
residuals([residualsType]) Get the residuals of the fitted model by type.
Attributes
aic Akaike’s “An Information Criterion” (AIC) for the fitted model.
degreesOfFreedom Degrees of freedom.
deviance The deviance for the fitted model.
dispersion The dispersion of the fitted model.
nullDeviance The deviance for the null model.
numInstances Number of instances in DataFrame predictions.
predictionCol Field in predictions which gives the predicted value of each instance.
predictions Predictions output by the model’s transform method.
rank The numeric rank of the fitted linear model.
residualDegreeOfFreedom The residual degrees of freedom.
residualDegreeOfFreedomNull The residual degrees of freedom for the null model.
Methods Documentation
residuals(residualsType='deviance')
Get the residuals of the fitted model by type.
New in version 2.0.0.
Parameters
residualsType [str, optional] The type of residuals which should be returned. Supported options: deviance (default), pearson, working, and response.
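A minimal sketch (summary is hypothetical, e.g. obtained from model.evaluate(test_df); each call returns a DataFrame of residuals of the requested type):

summary.residuals().show()           # deviance residuals (the default)
summary.residuals("pearson").show()  # Pearson residuals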
Attributes Documentation
aic
Akaike’s “An Information Criterion” (AIC) for the fitted model.

New in version 2.0.0.

degreesOfFreedom
Degrees of freedom.

New in version 2.0.0.

deviance
The deviance for the fitted model.

New in version 2.0.0.

dispersion
The dispersion of the fitted model. It is taken as 1.0 for the “binomial” and “poisson” families, and otherwise estimated by the residual Pearson’s Chi-Squared statistic (which is defined as the sum of the squares of the Pearson residuals) divided by the residual degrees of freedom.
New in version 2.0.0.
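As a formula, the estimate described above can be sketched in the notation of this summary, where r_{P,i} is the i-th Pearson residual:

\hat{\phi} = \frac{\sum_{i} r_{P,i}^{2}}{\mathrm{residualDegreeOfFreedom}}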
nullDeviance
The deviance for the null model.
New in version 2.0.0.
numInstances
Number of instances in DataFrame predictions.

New in version 2.2.0.

predictionCol
Field in predictions which gives the predicted value of each instance. This is set to a new column name if the original model’s predictionCol is not set.

New in version 2.0.0.

predictions
Predictions output by the model’s transform method.

New in version 2.0.0.

rank
The numeric rank of the fitted linear model.

New in version 2.0.0.

residualDegreeOfFreedom
The residual degrees of freedom.

New in version 2.0.0.

residualDegreeOfFreedomNull
The residual degrees of freedom for the null model.
New in version 2.0.0.
GeneralizedLinearRegressionTrainingSummary
class pyspark.ml.regression.GeneralizedLinearRegressionTrainingSummary(java_obj=None)
Generalized linear regression training results.
New in version 2.0.0.
Methods
residuals([residualsType]) Get the residuals of the fitted model by type.
Attributes
aic Akaike’s “An Information Criterion” (AIC) for the fitted model.
coefficientStandardErrors Standard error of estimated coefficients and intercept.
degreesOfFreedom Degrees of freedom.
deviance The deviance for the fitted model.
dispersion The dispersion of the fitted model.
nullDeviance The deviance for the null model.
numInstances Number of instances in DataFrame predictions.
numIterations Number of training iterations.
pValues Two-sided p-value of estimated coefficients and intercept.
predictionCol Field in predictions which gives the predicted value of each instance.
predictions Predictions output by the model’s transform method.
rank The numeric rank of the fitted linear model.
residualDegreeOfFreedom The residual degrees of freedom.
residualDegreeOfFreedomNull The residual degrees of freedom for the null model.
solver The numeric solver used for training.
tValues T-statistic of estimated coefficients and intercept.
Methods Documentation
residuals(residualsType='deviance')
Get the residuals of the fitted model by type.
New in version 2.0.0.
Parameters
residualsType [str, optional] The type of residuals which should be returned. Supported options: deviance (default), pearson, working, and response.
Attributes Documentation
aic
Akaike’s “An Information Criterion” (AIC) for the fitted model.

New in version 2.0.0.

coefficientStandardErrors
Standard error of estimated coefficients and intercept.

If GeneralizedLinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.

New in version 2.0.0.

degreesOfFreedom
Degrees of freedom.

New in version 2.0.0.

deviance
The deviance for the fitted model.

New in version 2.0.0.

dispersion
The dispersion of the fitted model. It is taken as 1.0 for the “binomial” and “poisson” families, and otherwise estimated by the residual Pearson’s Chi-Squared statistic (which is defined as the sum of the squares of the Pearson residuals) divided by the residual degrees of freedom.

New in version 2.0.0.

nullDeviance
The deviance for the null model.

New in version 2.0.0.
numInstances
Number of instances in DataFrame predictions.

New in version 2.2.0.

numIterations
Number of training iterations.

New in version 2.0.0.

pValues
Two-sided p-value of estimated coefficients and intercept.

If GeneralizedLinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.

New in version 2.0.0.

predictionCol
Field in predictions which gives the predicted value of each instance. This is set to a new column name if the original model’s predictionCol is not set.

New in version 2.0.0.

predictions
Predictions output by the model’s transform method.

New in version 2.0.0.

rank
The numeric rank of the fitted linear model.

New in version 2.0.0.

residualDegreeOfFreedom
The residual degrees of freedom.

New in version 2.0.0.

residualDegreeOfFreedomNull
The residual degrees of freedom for the null model.

New in version 2.0.0.

solver
The numeric solver used for training.

New in version 2.0.0.

tValues
T-statistic of estimated coefficients and intercept.

If GeneralizedLinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.

New in version 2.0.0.
IsotonicRegression
class pyspark.ml.regression.IsotonicRegression(*args, **kwargs)
Currently implemented using the parallelized pool adjacent violators algorithm. Only a univariate (single feature) algorithm is supported.
New in version 1.6.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> ir = IsotonicRegression()
>>> model = ir.fit(df)
>>> model.setFeaturesCol("features")
IsotonicRegressionModel...
>>> model.numFeatures
1
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
>>> model.predict(test0.head().features[model.getFeatureIndex()])
0.0
>>> model.boundaries
DenseVector([0.0, 1.0])
>>> ir_path = temp_path + "/ir"
>>> ir.save(ir_path)
>>> ir2 = IsotonicRegression.load(ir_path)
>>> ir2.getIsotonic()
True
>>> model_path = temp_path + "/ir_model"
>>> model.save(model_path)
>>> model2 = IsotonicRegressionModel.load(model_path)
>>> model.boundaries == model2.boundaries
True
>>> model.predictions == model2.predictions
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getFeatureIndex() Gets the value of featureIndex or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getIsotonic() Gets the value of isotonic or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setFeatureIndex(value) Sets the value of featureIndex.
setFeaturesCol(value) Sets the value of featuresCol.
setIsotonic(value) Sets the value of isotonic.
setLabelCol(value) Sets the value of labelCol.
setParams(*args, **kwargs) Set the params for IsotonicRegression.
setPredictionCol(value) Sets the value of predictionCol.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.
Attributes
featureIndex
featuresCol
isotonic
labelCol
params Returns all params ordered by name.
predictionCol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
884 Chapter 3. API Reference
pyspark Documentation, Release master
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getFeatureIndex()
Gets the value of featureIndex or its default value.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getIsotonic()
Gets the value of isotonic or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.
save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setFeatureIndex(value)
Sets the value of featureIndex.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 1.6.0.

setIsotonic(value)
Sets the value of isotonic.

setLabelCol(value)
Sets the value of labelCol.

New in version 1.6.0.

setParams(*args, **kwargs)
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", weightCol=None, isotonic=True, featureIndex=0): Set the params for IsotonicRegression.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 1.6.0.

setWeightCol(value)
Sets the value of weightCol.

New in version 1.6.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
featureIndex = Param(parent='undefined', name='featureIndex', doc='The index of the feature if featuresCol is a vector column, no effect otherwise.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
isotonic = Param(parent='undefined', name='isotonic', doc='whether the output sequence should be isotonic/increasing (true) or antitonic/decreasing (false).')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
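A minimal sketch of the featureIndex param documented above, for the case where featuresCol holds vectors (the index is illustrative; the regression is fit against component 1 of each feature vector):

from pyspark.ml.regression import IsotonicRegression

ir = IsotonicRegression(featureIndex=1)  # use the second component of featuresCol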
IsotonicRegressionModel
class pyspark.ml.regression.IsotonicRegressionModel(java_model=None)
Model fitted by IsotonicRegression.
New in version 1.6.0.
Methods
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFeatureIndex() Gets the value of featureIndex or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getIsotonic() Gets the value of isotonic or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) Predict label for the given features.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setFeatureIndex(value) Sets the value of featureIndex.
setFeaturesCol(value) Sets the value of featuresCol.
setPredictionCol(value) Sets the value of predictionCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.
Attributes
boundaries: Boundaries in increasing order for which predictions are known.
featureIndex
featuresCol
isotonic
labelCol
numFeatures: Returns the number of features the model was trained on.
params: Returns all params ordered by name.
predictionCol
predictions: Predictions associated with the boundaries at the same index, monotone because of isotonic regression.
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
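To make the merge ordering concrete, a small sketch (using IsotonicRegression, whose isotonic param defaults to True):

from pyspark.ml.regression import IsotonicRegression

iso = IsotonicRegression()
iso.setIsotonic(False)                              # user-supplied value overrides the default
merged = iso.extractParamMap({iso.isotonic: True})  # extra overrides both
print(merged[iso.isotonic])                         # True: default < user-supplied < extra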
getFeatureIndex()
Gets the value of featureIndex or its default value.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getIsotonic()
Gets the value of isotonic or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
Predict label for the given features.

New in version 3.0.0.

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setFeatureIndex(value)
Sets the value of featureIndex.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.
New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
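For example, a per-call override of an embedded param (model and df as in the sketch above; the output column name newCol is illustrative):

# Override predictionCol for this call only; the model itself is unchanged.
out = model.transform(df, {model.predictionCol: "newCol"})
out.select("features", "newCol").show()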
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
boundaries
Boundaries in increasing order for which predictions are known.
New in version 1.6.0.
featureIndex = Param(parent='undefined', name='featureIndex', doc='The index of the feature if featuresCol is a vector column, no effect otherwise.')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
isotonic = Param(parent='undefined', name='isotonic', doc='whether the output sequence should be isotonic/increasing (true) or antitonic/decreasing (false).')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.
New in version 3.0.0.
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
predictions
Predictions associated with the boundaries at the same index, monotone because of isotonic regression.
New in version 1.6.0.
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
LinearRegression
class pyspark.ml.regression.LinearRegression(*args, **kwargs)
Linear regression.

The learning objective is to minimize the specified loss function, with regularization. This supports two kinds of loss:
• squaredError (a.k.a squared loss)
• huber (a hybrid of squared error for relatively small errors and absolute error for relatively large ones, andwe estimate the scale parameter from training data)
This supports multiple types of regularization:
• none (a.k.a. ordinary least squares)
• L2 (ridge regression)
• L1 (Lasso)
• L2 + L1 (elastic net)
New in version 1.4.0.
Notes
Fitting with huber loss only supports none and L2 regularization.
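As an illustration of that constraint, a minimal sketch of a huber configuration (L2 regularization comes from regParam, with elasticNetParam left at 0.0):

from pyspark.ml.regression import LinearRegression

# Huber loss is robust to outliers; keep elasticNetParam at 0.0 because
# L1/elastic-net regularization is not supported with huber.
hlr = LinearRegression(loss="huber", epsilon=1.35,
                       regParam=0.1, elasticNetParam=0.0)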
Examples
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, 2.0, Vectors.dense(1.0)),
...     (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
>>> lr = LinearRegression(regParam=0.0, solver="normal", weightCol="weight")
>>> lr.setMaxIter(5)
LinearRegression...
>>> lr.getMaxIter()
5
>>> lr.setRegParam(0.1)
LinearRegression...
>>> lr.getRegParam()
0.1
>>> lr.setRegParam(0.0)
LinearRegression...
>>> model = lr.fit(df)
>>> model.setFeaturesCol("features")
LinearRegressionModel...
>>> model.setPredictionCol("newPrediction")
LinearRegressionModel...
>>> model.getMaxIter()
5
>>> model.getMaxBlockSizeInMB()
0.0
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> abs(model.predict(test0.head().features) - (-1.0)) < 0.001
True
>>> abs(model.transform(test0).head().newPrediction - (-1.0)) < 0.001
True
>>> abs(model.coefficients[0] - 1.0) < 0.001
True
>>> abs(model.intercept - 0.0) < 0.001
True
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> abs(model.transform(test1).head().newPrediction - 1.0) < 0.001
True
>>> lr.setParams("vector")
Traceback (most recent call last):
    ...
TypeError: Method setParams forces keyword arguments.
>>> lr_path = temp_path + "/lr"
>>> lr.save(lr_path)
>>> lr2 = LinearRegression.load(lr_path)
>>> lr2.getMaxIter()
5
>>> model_path = temp_path + "/lr_model"
>>> model.save(model_path)
>>> model2 = LinearRegressionModel.load(model_path)
>>> model.coefficients[0] == model2.coefficients[0]
True
>>> model.intercept == model2.intercept
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> model.numFeatures
1
>>> model.write().format("pmml").save(model_path + "_2")
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getAggregationDepth(): Gets the value of aggregationDepth or its default value.
getElasticNetParam(): Gets the value of elasticNetParam or its default value.
getEpsilon(): Gets the value of epsilon or its default value.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getFitIntercept(): Gets the value of fitIntercept or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getLoss(): Gets the value of loss or its default value.
getMaxBlockSizeInMB(): Gets the value of maxBlockSizeInMB or its default value.
getMaxIter(): Gets the value of maxIter or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getRegParam(): Gets the value of regParam or its default value.
getSolver(): Gets the value of solver or its default value.
getStandardization(): Gets the value of standardization or its default value.
getTol(): Gets the value of tol or its default value.
getWeightCol(): Gets the value of weightCol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value): Sets a parameter in the embedded param map.
setAggregationDepth(value): Sets the value of aggregationDepth.
setElasticNetParam(value): Sets the value of elasticNetParam.
setEpsilon(value): Sets the value of epsilon.
setFeaturesCol(value): Sets the value of featuresCol.
setFitIntercept(value): Sets the value of fitIntercept.
setLabelCol(value): Sets the value of labelCol.
setLoss(value): Sets the value of loss.
setMaxBlockSizeInMB(value): Sets the value of maxBlockSizeInMB.
setMaxIter(value): Sets the value of maxIter.
setParams(self, *[, featuresCol, labelCol, ...]): Sets params for linear regression.
setPredictionCol(value): Sets the value of predictionCol.
setRegParam(value): Sets the value of regParam.
setSolver(value): Sets the value of solver.
setStandardization(value): Sets the value of standardization.
setTol(value): Sets the value of tol.
setWeightCol(value): Sets the value of weightCol.
write(): Returns an MLWriter instance for this ML instance.
Attributes
aggregationDepth
elasticNetParam
epsilon
featuresCol
fitIntercept
labelCol
loss
maxBlockSizeInMB
maxIter
params: Returns all params ordered by name.
predictionCol
regParam
solver
standardization
tol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
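A short sketch of what that looks like in practice (reusing lr and df from the class example above):

# One model per param map; the iterator may yield them out of order.
maps = [{lr.regParam: 0.0}, {lr.regParam: 0.1}]
for index, fitted in lr.fitMultiple(df, maps):
    print(index, fitted.coefficients)

# fit(df, maps) is the eager equivalent: a list of models in map order.
models = lr.fit(df, maps)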
getAggregationDepth()
Gets the value of aggregationDepth or its default value.

getElasticNetParam()
Gets the value of elasticNetParam or its default value.

getEpsilon()
Gets the value of epsilon or its default value.

New in version 2.3.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFitIntercept()
Gets the value of fitIntercept or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getLoss()
Gets the value of loss or its default value.

getMaxBlockSizeInMB()
Gets the value of maxBlockSizeInMB or its default value.

getMaxIter()
Gets the value of maxIter or its default value.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getRegParam()
Gets the value of regParam or its default value.

getSolver()
Gets the value of solver or its default value.

getStandardization()
Gets the value of standardization or its default value.

getTol()
Gets the value of tol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setAggregationDepth(value)
Sets the value of aggregationDepth.

setElasticNetParam(value)
Sets the value of elasticNetParam.

setEpsilon(value)
Sets the value of epsilon.

New in version 2.3.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.
setFitIntercept(value)
Sets the value of fitIntercept.

setLabelCol(value)
Sets the value of labelCol.

New in version 3.0.0.

setLoss(value)
Sets the value of loss.

setMaxBlockSizeInMB(value)
Sets the value of maxBlockSizeInMB.

New in version 3.1.0.

setMaxIter(value)
Sets the value of maxIter.
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, standardization=True, solver="auto", weightCol=None, aggregationDepth=2, loss="squaredError", epsilon=1.35, maxBlockSizeInMB=0.0)
Sets params for linear regression.
New in version 1.4.0.
setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setRegParam(value)
Sets the value of regParam.

setSolver(value)
Sets the value of solver.

setStandardization(value)
Sets the value of standardization.

setTol(value)
Sets the value of tol.

setWeightCol(value)
Sets the value of weightCol.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')
elasticNetParam = Param(parent='undefined', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.')
epsilon = Param(parent='undefined', name='epsilon', doc='The shape parameter to control the amount of robustness. Must be > 1.0. Only valid when loss is huber')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
loss = Param(parent='undefined', name='loss', doc='The loss function to be optimized. Supported options: squaredError, huber.')
maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')
solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: auto, normal, l-bfgs.')
standardization = Param(parent='undefined', name='standardization', doc='whether to standardize the training features before fitting the model.')
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
LinearRegressionModel
class pyspark.ml.regression.LinearRegressionModel(java_model=None)
Model fitted by LinearRegression.
New in version 1.4.0.
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset): Evaluates the model on a test dataset.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getAggregationDepth(): Gets the value of aggregationDepth or its default value.
getElasticNetParam(): Gets the value of elasticNetParam or its default value.
getEpsilon(): Gets the value of epsilon or its default value.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getFitIntercept(): Gets the value of fitIntercept or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getLoss(): Gets the value of loss or its default value.
getMaxBlockSizeInMB(): Gets the value of maxBlockSizeInMB or its default value.
getMaxIter(): Gets the value of maxIter or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getRegParam(): Gets the value of regParam or its default value.
getSolver(): Gets the value of solver or its default value.
getStandardization(): Gets the value of standardization or its default value.
getTol(): Gets the value of tol or its default value.
getWeightCol(): Gets the value of weightCol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value): Predict label for the given features.
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value): Sets a parameter in the embedded param map.
setFeaturesCol(value): Sets the value of featuresCol.
setPredictionCol(value): Sets the value of predictionCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns a GeneralMLWriter instance for this ML instance.
Attributes
aggregationDepth
coefficients: Model coefficients.
elasticNetParam
epsilon
featuresCol
fitIntercept
hasSummary: Indicates whether a training summary exists for this model instance.
intercept: Model intercept.
labelCol
loss
maxBlockSizeInMB
maxIter
numFeatures: Returns the number of features the model was trained on.
params: Returns all params ordered by name.
predictionCol
regParam
scale: The value by which ‖y − X'w‖ is scaled down when loss is “huber”, otherwise 1.0.
solver
standardization
summary: Gets summary (e.g. residuals, mse, r-squared) of model on training set.
tol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
evaluate(dataset)
Evaluates the model on a test dataset.
New in version 2.0.0.
Parameters
dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on, where dataset is an instance of pyspark.sql.DataFrame
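A minimal sketch (model as fitted in the class example above; the test rows here are illustrative); evaluate() returns a LinearRegressionSummary computed on the supplied data:

from pyspark.ml.linalg import Vectors

test = spark.createDataFrame([
    (1.0, Vectors.dense(1.0)),
    (0.0, Vectors.dense(0.0)),
], ["label", "features"])

# Held-out evaluation: summary metrics on data the model never saw.
test_summary = model.evaluate(test)
print(test_summary.rootMeanSquaredError)
print(test_summary.meanAbsoluteError)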
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getAggregationDepth()
Gets the value of aggregationDepth or its default value.

getElasticNetParam()
Gets the value of elasticNetParam or its default value.

getEpsilon()
Gets the value of epsilon or its default value.

New in version 2.3.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFitIntercept()
Gets the value of fitIntercept or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getLoss()
Gets the value of loss or its default value.

getMaxBlockSizeInMB()
Gets the value of maxBlockSizeInMB or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getRegParam()
Gets the value of regParam or its default value.

getSolver()
Gets the value of solver or its default value.

getStandardization()
Gets the value of standardization or its default value.

getTol()
Gets the value of tol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.
isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
Predict label for the given features.

New in version 3.0.0.

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns a GeneralMLWriter instance for this ML instance.
Attributes Documentation
aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')
coefficients
Model coefficients.
New in version 2.0.0.
elasticNetParam = Param(parent='undefined', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.')
epsilon = Param(parent='undefined', name='epsilon', doc='The shape parameter to control the amount of robustness. Must be > 1.0. Only valid when loss is huber')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
hasSummary
Indicates whether a training summary exists for this model instance.
New in version 2.1.0.
intercept
Model intercept.
New in version 1.4.0.
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
loss = Param(parent='undefined', name='loss', doc='The loss function to be optimized. Supported options: squaredError, huber.')
maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')
scale
The value by which ‖y − X'w‖ is scaled down when loss is “huber”, otherwise 1.0.
New in version 2.3.0.
solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: auto, normal, l-bfgs.')
standardization = Param(parent='undefined', name='standardization', doc='whether to standardize the training features before fitting the model.')
summary
Gets summary (e.g. residuals, mse, r-squared) of model on training set. An exception is thrown if trainingSummary is None.
New in version 2.0.0.
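For example, a short sketch of reading training metrics off a freshly fitted model (models reloaded from disk carry no summary, hence the guard):

if model.hasSummary:
    s = model.summary
    print(s.rootMeanSquaredError)
    print(s.r2)
    print(s.numInstances)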
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
LinearRegressionSummary
class pyspark.ml.regression.LinearRegressionSummary(java_obj=None)
Linear regression results evaluated on a dataset.
New in version 2.0.0.
Attributes
coefficientStandardErrors: Standard error of estimated coefficients and intercept.
degreesOfFreedom: Degrees of freedom.
devianceResiduals: The weighted residuals, the usual residuals rescaled by the square root of the instance weights.
explainedVariance: Returns the explained variance regression score.
featuresCol: Field in “predictions” which gives the features of each instance as a vector.
labelCol: Field in “predictions” which gives the true label of each instance.
meanAbsoluteError: Returns the mean absolute error, which is a risk function corresponding to the expected value of the absolute error loss or l1-norm loss.
meanSquaredError: Returns the mean squared error, which is a risk function corresponding to the expected value of the squared error loss or quadratic loss.
numInstances: Number of instances in DataFrame predictions.
pValues: Two-sided p-value of estimated coefficients and intercept.
predictionCol: Field in “predictions” which gives the predicted value of the label at each instance.
predictions: Dataframe outputted by the model’s transform method.
r2: Returns R^2, the coefficient of determination.
r2adj: Returns Adjusted R^2, the adjusted coefficient of determination.
residuals: Residuals (label - predicted value).
rootMeanSquaredError: Returns the root mean squared error, which is defined as the square root of the mean squared error.
tValues: T-statistic of estimated coefficients and intercept.
Attributes Documentation
coefficientStandardErrors
Standard error of estimated coefficients and intercept. This value is only available when using the “normal” solver.

If LinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.

New in version 2.0.0.

See also:

LinearRegression.solver

degreesOfFreedom
Degrees of freedom.

New in version 2.2.0.
devianceResiduals
The weighted residuals, the usual residuals rescaled by the square root of the instance weights.
New in version 2.0.0.
explainedVariance
Returns the explained variance regression score. explainedVariance = 1 − variance(y − ŷ) / variance(y)
Notes
This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.
For additional information see Explained variation on Wikipedia
New in version 2.0.0.
featuresCol
Field in “predictions” which gives the features of each instance as a vector.

New in version 2.0.0.

labelCol
Field in “predictions” which gives the true label of each instance.

New in version 2.0.0.
meanAbsoluteError
Returns the mean absolute error, which is a risk function corresponding to the expected value of the absolute error loss or l1-norm loss.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.
New in version 2.0.0.
meanSquaredError
Returns the mean squared error, which is a risk function corresponding to the expected value of the squared error loss or quadratic loss.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.
New in version 2.0.0.
numInstances
Number of instances in DataFrame predictions.
New in version 2.0.0.
pValues
Two-sided p-value of estimated coefficients and intercept. This value is only available when using the “normal” solver.

If LinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.
New in version 2.0.0.
See also:
LinearRegression.solver
predictionCol
Field in “predictions” which gives the predicted value of the label at each instance.

New in version 2.0.0.

predictions
Dataframe outputted by the model’s transform method.

New in version 2.0.0.

r2
Returns R^2, the coefficient of determination.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.
See also Wikipedia coefficient of determination
New in version 2.0.0.
r2adj
Returns Adjusted R^2, the adjusted coefficient of determination.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.
Wikipedia coefficient of determination, Adjusted R^2
New in version 2.4.0.
residuals
Residuals (label - predicted value)

New in version 2.0.0.

rootMeanSquaredError
Returns the root mean squared error, which is defined as the square root of the mean squared error.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.
New in version 2.0.0.
tValues
T-statistic of estimated coefficients and intercept. This value is only available when using the “normal” solver.
If LinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.
New in version 2.0.0.
See also:
LinearRegression.solver
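Since coefficientStandardErrors, tValues and pValues all require the “normal” solver, a combined sketch (model fit with solver="normal", as in the LinearRegression example):

s = model.summary
# Available only for solver="normal"; with fitIntercept=True the last
# entry of each sequence corresponds to the intercept.
print(s.coefficientStandardErrors)
print(s.tValues)
print(s.pValues)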
LinearRegressionTrainingSummary
class pyspark.ml.regression.LinearRegressionTrainingSummary(java_obj=None)
Linear regression training results. Currently, the training summary ignores the training weights except for the objective trace.
New in version 2.0.0.
Attributes
coefficientStandardErrors: Standard error of estimated coefficients and intercept.
degreesOfFreedom: Degrees of freedom.
devianceResiduals: The weighted residuals, the usual residuals rescaled by the square root of the instance weights.
explainedVariance: Returns the explained variance regression score.
featuresCol: Field in “predictions” which gives the features of each instance as a vector.
labelCol: Field in “predictions” which gives the true label of each instance.
meanAbsoluteError: Returns the mean absolute error, which is a risk function corresponding to the expected value of the absolute error loss or l1-norm loss.
meanSquaredError: Returns the mean squared error, which is a risk function corresponding to the expected value of the squared error loss or quadratic loss.
numInstances: Number of instances in DataFrame predictions.
objectiveHistory: Objective function (scaled loss + regularization) at each iteration.
pValues: Two-sided p-value of estimated coefficients and intercept.
predictionCol: Field in “predictions” which gives the predicted value of the label at each instance.
predictions: Dataframe outputted by the model’s transform method.
r2: Returns R^2, the coefficient of determination.
r2adj: Returns Adjusted R^2, the adjusted coefficient of determination.
residuals: Residuals (label - predicted value).
rootMeanSquaredError: Returns the root mean squared error, which is defined as the square root of the mean squared error.
tValues: T-statistic of estimated coefficients and intercept.
totalIterations: Number of training iterations until termination.
Attributes Documentation
coefficientStandardErrors
Standard error of estimated coefficients and intercept. This value is only available when using the “normal” solver.

If LinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.

New in version 2.0.0.

See also:

LinearRegression.solver

degreesOfFreedom
Degrees of freedom.

New in version 2.2.0.

devianceResiduals
The weighted residuals, the usual residuals rescaled by the square root of the instance weights.
New in version 2.0.0.
explainedVariance
Returns the explained variance regression score. explainedVariance = 1 − variance(y − ŷ) / variance(y)
Notes
This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.
For additional information see Explained variation on Wikipedia
New in version 2.0.0.
featuresCol
Field in “predictions” which gives the features of each instance as a vector.

New in version 2.0.0.

labelCol
Field in “predictions” which gives the true label of each instance.

New in version 2.0.0.

meanAbsoluteError
Returns the mean absolute error, which is a risk function corresponding to the expected value of the absolute error loss or l1-norm loss.
Notes
This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.
New in version 2.0.0.
meanSquaredError
Returns the mean squared error, which is a risk function corresponding to the expected value of the squared error loss or quadratic loss.
Notes
This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.
New in version 2.0.0.
numInstances
Number of instances in DataFrame predictions.
New in version 2.0.0.
objectiveHistory
Objective function (scaled loss + regularization) at each iteration. This value is only available when using the “l-bfgs” solver.
New in version 2.0.0.
See also:
LinearRegression.solver
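For instance, a sketch of inspecting convergence (assumes a labeled DataFrame df as in the LinearRegression example; the solver must be l-bfgs for the history to exist):

from pyspark.ml.regression import LinearRegression

lbfgs_model = LinearRegression(solver="l-bfgs", maxIter=10).fit(df)
# One objective value per iteration; a flattening tail suggests convergence.
for i, obj in enumerate(lbfgs_model.summary.objectiveHistory):
    print(i, obj)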
pValues
Two-sided p-value of estimated coefficients and intercept. This value is only available when using the “normal” solver.

If LinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.
New in version 2.0.0.
See also:
LinearRegression.solver
predictionCol
Field in “predictions” which gives the predicted value of the label at each instance.

New in version 2.0.0.

predictions
Dataframe outputted by the model’s transform method.

New in version 2.0.0.

r2
Returns R^2, the coefficient of determination.
Notes
This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.
See also Wikipedia coefficient of determination
New in version 2.0.0.
r2adj
Returns Adjusted R^2, the adjusted coefficient of determination.
Notes
This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.
Wikipedia coefficient of determination, Adjusted R^2
New in version 2.4.0.
residuals
Residuals (label - predicted value)

New in version 2.0.0.

rootMeanSquaredError
Returns the root mean squared error, which is defined as the square root of the mean squared error.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.
New in version 2.0.0.
tValues
T-statistic of estimated coefficients and intercept. This value is only available when using the “normal” solver.

If LinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.
New in version 2.0.0.
See also:
LinearRegression.solver
totalIterations
Number of training iterations until termination. This value is only available when using the “l-bfgs” solver.
New in version 2.0.0.
See also:
LinearRegression.solver
RandomForestRegressor
class pyspark.ml.regression.RandomForestRegressor(*args, **kwargs)
Random Forest learning algorithm for regression. It supports both continuous and categorical features.
New in version 1.4.0.
Examples
>>> from numpy import allclose
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> rf = RandomForestRegressor(numTrees=2, maxDepth=2)
>>> rf.getMinWeightFractionPerNode()
0.0
>>> rf.setSeed(42)
RandomForestRegressor...
>>> model = rf.fit(df)
>>> model.getBootstrap()
True
>>> model.getSeed()
42
>>> model.setLeafCol("leafId")
RandomForestRegressionModel...
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 1.0])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictLeaf(test0.head().features)
DenseVector([0.0, 0.0])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.leafId
DenseVector([0.0, 0.0])
>>> model.numFeatures
1
>>> model.trees
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
>>> model.getNumTrees
2
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
0.5
>>> rfr_path = temp_path + "/rfr"
>>> rf.save(rfr_path)
>>> rf2 = RandomForestRegressor.load(rfr_path)
>>> rf2.getNumTrees()
2
>>> model_path = temp_path + "/rfr_model"
>>> model.save(model_path)
>>> model2 = RandomForestRegressionModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
Methods
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getBootstrap(): Gets the value of bootstrap or its default value.
getCacheNodeIds(): Gets the value of cacheNodeIds or its default value.
getCheckpointInterval(): Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy(): Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getImpurity(): Gets the value of impurity or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getLeafCol(): Gets the value of leafCol or its default value.
getMaxBins(): Gets the value of maxBins or its default value.
getMaxDepth(): Gets the value of maxDepth or its default value.
getMaxMemoryInMB(): Gets the value of maxMemoryInMB or its default value.
getMinInfoGain(): Gets the value of minInfoGain or its default value.
getMinInstancesPerNode(): Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode(): Gets the value of minWeightFractionPerNode or its default value.
getNumTrees(): Gets the value of numTrees or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getSeed(): Gets the value of seed or its default value.
getSubsamplingRate(): Gets the value of subsamplingRate or its default value.
getWeightCol(): Gets the value of weightCol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value): Sets a parameter in the embedded param map.
setBootstrap(value): Sets the value of bootstrap.
setCacheNodeIds(value): Sets the value of cacheNodeIds.
setCheckpointInterval(value): Sets the value of checkpointInterval.
setFeatureSubsetStrategy(value): Sets the value of featureSubsetStrategy.
setFeaturesCol(value): Sets the value of featuresCol.
setImpurity(value): Sets the value of impurity.
setLabelCol(value): Sets the value of labelCol.
setLeafCol(value): Sets the value of leafCol.
setMaxBins(value): Sets the value of maxBins.
setMaxDepth(value): Sets the value of maxDepth.
setMaxMemoryInMB(value): Sets the value of maxMemoryInMB.
setMinInfoGain(value): Sets the value of minInfoGain.
setMinInstancesPerNode(value): Sets the value of minInstancesPerNode.
setMinWeightFractionPerNode(value): Sets the value of minWeightFractionPerNode.
setNumTrees(value): Sets the value of numTrees.
setParams(self, *[, featuresCol, labelCol, ...]): Sets params for random forest regression.
setPredictionCol(value): Sets the value of predictionCol.
setSeed(value): Sets the value of seed.
setSubsamplingRate(value): Sets the value of subsamplingRate.
setWeightCol(value): Sets the value of weightCol.
write(): Returns an MLWriter instance for this ML instance.
Attributes
bootstrap
cacheNodeIds
checkpointInterval
featureSubsetStrategy
featuresCol
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numTrees
params: Returns all params ordered by name.
predictionCol
seed
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getBootstrap()
Gets the value of bootstrap or its default value.

New in version 3.0.0.

getCacheNodeIds()
Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getFeatureSubsetStrategy()
Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getImpurity()
Gets the value of impurity or its default value.

New in version 1.4.0.

getLabelCol()
Gets the value of labelCol or its default value.

getLeafCol()
Gets the value of leafCol or its default value.

getMaxBins()
Gets the value of maxBins or its default value.

getMaxDepth()
Gets the value of maxDepth or its default value.
getMaxMemoryInMB()
Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
Gets the value of minWeightFractionPerNode or its default value.

getNumTrees()
Gets the value of numTrees or its default value.

New in version 1.4.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getSubsamplingRate()
Gets the value of subsamplingRate or its default value.

New in version 1.4.0.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setBootstrap(value)
Sets the value of bootstrap.
New in version 3.0.0.
setCacheNodeIds(value)
Sets the value of cacheNodeIds.

setCheckpointInterval(value)
Sets the value of checkpointInterval.

setFeatureSubsetStrategy(value)
Sets the value of featureSubsetStrategy.

New in version 2.4.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setImpurity(value)
Sets the value of impurity.

New in version 1.4.0.

setLabelCol(value)
Sets the value of labelCol.

New in version 3.0.0.

setLeafCol(value)
Sets the value of leafCol.

setMaxBins(value)
Sets the value of maxBins.

setMaxDepth(value)
Sets the value of maxDepth.

setMaxMemoryInMB(value)
Sets the value of maxMemoryInMB.

setMinInfoGain(value)
Sets the value of minInfoGain.

setMinInstancesPerNode(value)
Sets the value of minInstancesPerNode.

setMinWeightFractionPerNode(value)
Sets the value of minWeightFractionPerNode.

New in version 3.0.0.

setNumTrees(value)
Sets the value of numTrees.
New in version 1.4.0.
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)
Sets params for random forest regression.
New in version 1.4.0.
setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setSeed(value)
Sets the value of seed.

setSubsamplingRate(value)
Sets the value of subsamplingRate.

New in version 1.4.0.

setWeightCol(value)
Sets the value of weightCol.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
bootstrap = Param(parent='undefined', name='bootstrap', doc='Whether bootstrap samples are used when building trees.')
cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')
maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')
maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')
maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')
minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')
minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')
minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')
numTrees = Param(parent='undefined', name='numTrees', doc='Number of trees to train (>= 1).')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')
supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']
supportedImpurities = ['variance']
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
RandomForestRegressionModel
class pyspark.ml.regression.RandomForestRegressionModel(java_model=None)
Model fitted by RandomForestRegressor.
New in version 1.4.0.
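This release of the reference carries no class-level example here; the following is a minimal, hedged sketch (it assumes an active SparkSession bound to spark, as in the doctests elsewhere on this page, and uses illustrative toy data):

from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import RandomForestRegressor

# Toy training data: a single feature column (illustrative only).
df = spark.createDataFrame([
    (1.0, Vectors.dense(1.0)),
    (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
model = rf.fit(df)                 # -> RandomForestRegressionModel
model.getNumTrees                  # 2; note: a property, not a method call
model.featureImportances           # importance Vector, normalized to sum to 1
model.transform(df).select("features", "prediction").show()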
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBootstrap()  Gets the value of bootstrap or its default value.
getCacheNodeIds()  Gets the value of cacheNodeIds or its default value.
getCheckpointInterval()  Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy()  Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getImpurity()  Gets the value of impurity or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getLeafCol()  Gets the value of leafCol or its default value.
getMaxBins()  Gets the value of maxBins or its default value.
getMaxDepth()  Gets the value of maxDepth or its default value.
getMaxMemoryInMB()  Gets the value of maxMemoryInMB or its default value.
getMinInfoGain()  Gets the value of minInfoGain or its default value.
getMinInstancesPerNode()  Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode()  Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getSeed()  Gets the value of seed or its default value.
getSubsamplingRate()  Gets the value of subsamplingRate or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)  Predict label for the given features.
predictLeaf(value)  Predict the indices of the leaves corresponding to the feature vector.
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setLeafCol(value)  Sets the value of leafCol.
setPredictionCol(value)  Sets the value of predictionCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
bootstrap
cacheNodeIds
checkpointInterval
featureImportances  Estimate of the importance of each feature.
featureSubsetStrategy
featuresCol
getNumTrees  Number of trees in ensemble.
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numFeatures  Returns the number of features the model was trained on.
numTrees
params  Returns all params ordered by name.
predictionCol
seed
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
toDebugString  Full description of model.
totalNumNodes  Total number of nodes, summed over all trees in the ensemble.
treeWeights  Return the weights for each tree.
trees  Trees in this ensemble.
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
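As a concrete illustration of that ordering, here is a hedged sketch reusing the rf estimator from the example above (the stated default of 5 for maxDepth is the usual library default, noted here as an assumption):

rf.extractParamMap()[rf.maxDepth]                  # 2: user-supplied beats the default (5)
rf.extractParamMap({rf.maxDepth: 7})[rf.maxDepth]  # 7: the extra map beats both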
getBootstrap()
Gets the value of bootstrap or its default value.

New in version 3.0.0.

getCacheNodeIds()
Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy()
Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getImpurity()
Gets the value of impurity or its default value.

New in version 1.4.0.

getLabelCol()
Gets the value of labelCol or its default value.

getLeafCol()
Gets the value of leafCol or its default value.

getMaxBins()
Gets the value of maxBins or its default value.

getMaxDepth()
Gets the value of maxDepth or its default value.

getMaxMemoryInMB()
Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getSubsamplingRate()
Gets the value of subsamplingRate or its default value.

New in version 1.4.0.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.
isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
Predict label for the given features.

New in version 3.0.0.

predictLeaf(value)
Predict the indices of the leaves corresponding to the feature vector.

New in version 3.0.0.
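Both methods operate on a single pyspark.ml.linalg vector rather than a DataFrame; a brief hedged sketch (model as fitted in the example above):

from pyspark.ml.linalg import Vectors

v = Vectors.dense(1.0)
model.predict(v)       # float: the ensemble's prediction for this one vector
model.predictLeaf(v)   # Vector of predicted leaf indices, one entry per tree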
classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setLeafCol(value)
Sets the value of leafCol.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
bootstrap = Param(parent='undefined', name='bootstrap', doc='Whether bootstrap samples are used when building trees.')
cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')
checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')
featureImportances
Estimate of the importance of each feature.

Each feature's importance is the average of its importance across all trees in the ensemble. The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) and follows the implementation from scikit-learn.
New in version 2.0.0.
Examples
See DecisionTreeRegressionModel.featureImportances.
featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
getNumTrees
Number of trees in ensemble.
New in version 2.0.0.
impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')
maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')
maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')
maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')
minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')
minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')
minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')
numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
numTrees = Param(parent='undefined', name='numTrees', doc='Number of trees to train (>= 1).')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
seed = Param(parent='undefined', name='seed', doc='random seed.')
subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')
supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']
supportedImpurities = ['variance']
toDebugString
Full description of model.

New in version 2.0.0.

totalNumNodes
Total number of nodes, summed over all trees in the ensemble.

New in version 2.0.0.

treeWeights
Return the weights for each tree.

New in version 1.5.0.

trees
Trees in this ensemble. Warning: These have null parent Estimators.
New in version 2.0.0.
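Together these attributes allow per-tree inspection of the ensemble; a short hedged sketch (model as in the example above):

for tree, weight in zip(model.trees, model.treeWeights):
    # Each element of model.trees is a DecisionTreeRegressionModel.
    print(weight, tree.numNodes)
print(model.toDebugString)   # textual dump of every tree in the ensemble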
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
FMRegressor
class pyspark.ml.regression.FMRegressor(*args, **kwargs)
Factorization Machines learning algorithm for regression.
The solver parameter supports:
• gd (normal mini-batch gradient descent)
• adamW (default)
New in version 3.0.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.regression import FMRegressor
>>> df = spark.createDataFrame([
...     (2.0, Vectors.dense(2.0)),
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> fm = FMRegressor(factorSize=2)
>>> fm.setSeed(16)
FMRegressor...
>>> model = fm.fit(df)
>>> model.getMaxIter()
100
>>> test0 = spark.createDataFrame([
...     (Vectors.dense(-2.0),),
...     (Vectors.dense(0.5),),
...     (Vectors.dense(1.0),),
...     (Vectors.dense(4.0),)], ["features"])
>>> model.transform(test0).show(10, False)
+--------+-------------------+
|features|prediction         |
+--------+-------------------+
|[-2.0]  |-1.9989237712341565|
|[0.5]   |0.4956682219523814 |
|[1.0]   |0.994586620589689  |
|[4.0]   |3.9880970124135344 |
+--------+-------------------+
...
>>> model.intercept
-0.0032501766849261557
>>> model.linear
DenseVector([0.9978])
>>> model.factors
DenseMatrix(1, 2, [0.0173, 0.0021], 1)
>>> model_path = temp_path + "/fm_model"
>>> model.save(model_path)
>>> model2 = FMRegressionModel.load(model_path)
>>> model2.intercept
-0.0032501766849261557
>>> model2.linear
DenseVector([0.9978])
>>> model2.factors
DenseMatrix(1, 2, [0.0173, 0.0021], 1)
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getFactorSize()  Gets the value of factorSize or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFitIntercept()  Gets the value of fitIntercept or its default value.
getFitLinear()  Gets the value of fitLinear or its default value.
getInitStd()  Gets the value of initStd or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getMaxIter()  Gets the value of maxIter or its default value.
getMiniBatchFraction()  Gets the value of miniBatchFraction or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getRegParam()  Gets the value of regParam or its default value.
getSeed()  Gets the value of seed or its default value.
getSolver()  Gets the value of solver or its default value.
getStepSize()  Gets the value of stepSize or its default value.
getTol()  Gets the value of tol or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFactorSize(value)  Sets the value of factorSize.
setFeaturesCol(value)  Sets the value of featuresCol.
setFitIntercept(value)  Sets the value of fitIntercept.
setFitLinear(value)  Sets the value of fitLinear.
setInitStd(value)  Sets the value of initStd.
setLabelCol(value)  Sets the value of labelCol.
setMaxIter(value)  Sets the value of maxIter.
setMiniBatchFraction(value)  Sets the value of miniBatchFraction.
setParams(self, *[, featuresCol, labelCol, ...])  Sets Params for FMRegressor.
setPredictionCol(value)  Sets the value of predictionCol.
setRegParam(value)  Sets the value of regParam.
setSeed(value)  Sets the value of seed.
setSolver(value)  Sets the value of solver.
setStepSize(value)  Sets the value of stepSize.
setTol(value)  Sets the value of tol.
write()  Returns an MLWriter instance for this ML instance.
Attributes
factorSize
featuresCol
fitIntercept
fitLinear
initStd
labelCol
maxIter
miniBatchFraction
params  Returns all params ordered by name.
predictionCol
regParam
seed
solver
stepSize
tol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
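A hedged usage sketch: because models may complete out of order, they can be collected back into param-map order by index (the grid and the DataFrame df here are illustrative):

from pyspark.ml.regression import FMRegressor
from pyspark.ml.tuning import ParamGridBuilder

fm = FMRegressor()
grid = ParamGridBuilder().addGrid(fm.regParam, [0.0, 0.1]).build()
models = [None] * len(grid)
for index, model in fm.fitMultiple(df, grid):
    models[index] = model   # models[i] corresponds to grid[i]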
getFactorSize()
Gets the value of factorSize or its default value.

New in version 3.0.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFitIntercept()
Gets the value of fitIntercept or its default value.

getFitLinear()
Gets the value of fitLinear or its default value.

New in version 3.0.0.

getInitStd()
Gets the value of initStd or its default value.

New in version 3.0.0.

getLabelCol()
Gets the value of labelCol or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getMiniBatchFraction()
Gets the value of miniBatchFraction or its default value.

New in version 3.0.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getRegParam()
Gets the value of regParam or its default value.

getSeed()
Gets the value of seed or its default value.

getSolver()
Gets the value of solver or its default value.

getStepSize()
Gets the value of stepSize or its default value.

getTol()
Gets the value of tol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFactorSize(value)
Sets the value of factorSize.

New in version 3.0.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setFitIntercept(value)
Sets the value of fitIntercept.
New in version 3.0.0.
setFitLinear(value)
Sets the value of fitLinear.

New in version 3.0.0.

setInitStd(value)
Sets the value of initStd.

New in version 3.0.0.

setLabelCol(value)
Sets the value of labelCol.

New in version 3.0.0.

setMaxIter(value)
Sets the value of maxIter.

New in version 3.0.0.

setMiniBatchFraction(value)
Sets the value of miniBatchFraction.

New in version 3.0.0.

setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, tol=1e-6, solver="adamW", seed=None)
Sets Params for FMRegressor.
New in version 3.0.0.
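A brief hedged sketch: setParams accepts the same keyword arguments as the constructor, so an existing estimator can be reconfigured in place (values are illustrative):

fm = FMRegressor(factorSize=2)
fm.setParams(factorSize=8, regParam=0.01, solver="gd")  # keyword-only arguments
fm.getFactorSize()   # 8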
setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setRegParam(value)
Sets the value of regParam.

New in version 3.0.0.

setSeed(value)
Sets the value of seed.

New in version 3.0.0.

setSolver(value)
Sets the value of solver.

New in version 3.0.0.

setStepSize(value)
Sets the value of stepSize.

New in version 3.0.0.

setTol(value)
Sets the value of tol.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
factorSize = Param(parent='undefined', name='factorSize', doc='Dimensionality of the factor vectors, which are used to get pairwise interactions between variables')
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
fitLinear = Param(parent='undefined', name='fitLinear', doc='whether to fit linear term (aka 1-way term)')
initStd = Param(parent='undefined', name='initStd', doc='standard deviation of initial coefficients')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
miniBatchFraction = Param(parent='undefined', name='miniBatchFraction', doc='fraction of the input data set that should be used for one iteration of gradient descent')
paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')
seed = Param(parent='undefined', name='seed', doc='random seed.')
solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: gd, adamW. (Default adamW)')
stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
FMRegressionModel
class pyspark.ml.regression.FMRegressionModel(java_model=None)
Model fitted by FMRegressor.
New in version 3.0.0.
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFactorSize()  Gets the value of factorSize or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFitIntercept()  Gets the value of fitIntercept or its default value.
getFitLinear()  Gets the value of fitLinear or its default value.
getInitStd()  Gets the value of initStd or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getMaxIter()  Gets the value of maxIter or its default value.
getMiniBatchFraction()  Gets the value of miniBatchFraction or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getRegParam()  Gets the value of regParam or its default value.
getSeed()  Gets the value of seed or its default value.
getSolver()  Gets the value of solver or its default value.
getStepSize()  Gets the value of stepSize or its default value.
getTol()  Gets the value of tol or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)  Predict label for the given features.
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setPredictionCol(value)  Sets the value of predictionCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.
Attributes
factorSize
factors  Model factor term.
featuresCol
fitIntercept
fitLinear
initStd
intercept  Model intercept.
labelCol
linear  Model linear term.
maxIter
miniBatchFraction
numFeatures  Returns the number of features the model was trained on.
params  Returns all params ordered by name.
predictionCol
regParam
seed
solver
stepSize
tol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getFactorSize()
Gets the value of factorSize or its default value.

New in version 3.0.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFitIntercept()
Gets the value of fitIntercept or its default value.

getFitLinear()
Gets the value of fitLinear or its default value.

New in version 3.0.0.

getInitStd()
Gets the value of initStd or its default value.

New in version 3.0.0.

getLabelCol()
Gets the value of labelCol or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getMiniBatchFraction()
Gets the value of miniBatchFraction or its default value.

New in version 3.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getRegParam()
Gets the value of regParam or its default value.

getSeed()
Gets the value of seed or its default value.

getSolver()
Gets the value of solver or its default value.

getStepSize()
Gets the value of stepSize or its default value.

getTol()
Gets the value of tol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.
hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
Predict label for the given features.

New in version 3.0.0.

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
factorSize = Param(parent='undefined', name='factorSize', doc='Dimensionality of the factor vectors, which are used to get pairwise interactions between variables')
factors
Model factor term.
New in version 3.0.0.
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')
fitLinear = Param(parent='undefined', name='fitLinear', doc='whether to fit linear term (aka 1-way term)')
initStd = Param(parent='undefined', name='initStd', doc='standard deviation of initial coefficients')
intercept
Model intercept.
New in version 3.0.0.
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
linear
Model linear term.
New in version 3.0.0.
maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')
miniBatchFraction = Param(parent='undefined', name='miniBatchFraction', doc='fraction of the input data set that should be used for one iteration of gradient descent')
numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.
New in version 2.1.0.
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')
seed = Param(parent='undefined', name='seed', doc='random seed.')
solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: gd, adamW. (Default adamW)')
stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')
tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
3.3.10 Statistics
ANOVATest()  Conduct ANOVA Classification Test for continuous features against categorical labels.
ChiSquareTest()  Conduct Pearson's independence test for every feature against the label.
Correlation()  Compute the correlation matrix for the input dataset of Vectors using the specified method.
FValueTest()  Conduct F Regression test for continuous features against continuous labels.
KolmogorovSmirnovTest()  Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a continuous distribution.
MultivariateGaussian(mean, cov)  Represents a (mean, cov) tuple.
Summarizer()  Tools for vectorized statistics on MLlib Vectors.
SummaryBuilder(jSummaryBuilder)  A builder object that provides summary statistics about a given column.
ANOVATest
class pyspark.ml.stat.ANOVATest
Conduct ANOVA Classification Test for continuous features against categorical labels.
New in version 3.1.0.
Methods
test(dataset, featuresCol, labelCol[, flatten]) Perform an ANOVA test using dataset.
Methods Documentation
static test(dataset, featuresCol, labelCol, flatten=False)
Perform an ANOVA test using dataset.
New in version 3.1.0.
Parameters
dataset [pyspark.sql.DataFrame] DataFrame of categorical labels and continuous features.
featuresCol [str] Name of features column in dataset, of type Vector (VectorUDT).
labelCol [str] Name of label column in dataset, of any numerical type.
flatten [bool, optional] if True, flattens the returned dataframe.
Returns
pyspark.sql.DataFrame DataFrame containing the test result for every featureagainst the label. If flatten is True, this DataFrame will contain one row per feature withthe following fields:
• featureIndex: int
• pValue: float
• degreesOfFreedom: int
• fValue: float
If flatten is False, this DataFrame will contain a single Row with the following fields:
• pValues: Vector
• degreesOfFreedom: Array[int]
• fValues: Vector
Each of these fields has one value per feature.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.stat import ANOVATest
>>> dataset = [[2.0, Vectors.dense([0.43486404, 0.57153633, 0.43175686,
...                                 0.51418671, 0.61632374, 0.96565515])],
...            [1.0, Vectors.dense([0.49162732, 0.6785187, 0.85460572,
...                                 0.59784822, 0.12394819, 0.53783355])],
...            [2.0, Vectors.dense([0.30879653, 0.54904515, 0.17103889,
...                                 0.40492506, 0.18957493, 0.5440016])],
...            [3.0, Vectors.dense([0.68114391, 0.60549825, 0.69094651,
...                                 0.62102109, 0.05471483, 0.96449167])]]
>>> dataset = spark.createDataFrame(dataset, ["label", "features"])
>>> anovaResult = ANOVATest.test(dataset, 'features', 'label')
>>> row = anovaResult.select("fValues", "pValues").collect()
>>> row[0].fValues
DenseVector([4.0264, 18.4713, 3.4659, 1.9042, 0.5532, 0.512])
>>> row[0].pValues
DenseVector([0.3324, 0.1623, 0.3551, 0.456, 0.689, 0.7029])
>>> anovaResult = ANOVATest.test(dataset, 'features', 'label', True)
>>> row = anovaResult.orderBy("featureIndex").collect()
>>> row[0].fValue
4.026438671875297
ChiSquareTest
class pyspark.ml.stat.ChiSquareTest
Conduct Pearson's independence test for every feature against the label. For each feature, the (feature, label) pairs are converted into a contingency matrix for which the Chi-squared statistic is computed. All label and feature values must be categorical.
The null hypothesis is that the occurrence of the outcomes is statistically independent.
New in version 2.2.0.
Methods
test(dataset, featuresCol, labelCol[, flatten]) Perform a Pearson’s independence test using dataset.
Methods Documentation
static test(dataset, featuresCol, labelCol, flatten=False)
Perform a Pearson's independence test using dataset.
New in version 2.2.0.
Changed in version 3.1.0: Added optional flatten argument.
Parameters
dataset [pyspark.sql.DataFrame] DataFrame of categorical labels and categorical features. Real-valued features will be treated as categorical for each distinct value.
featuresCol [str] Name of features column in dataset, of type Vector (VectorUDT).
labelCol [str] Name of label column in dataset, of any numerical type.
flatten [bool, optional] if True, flattens the returned dataframe.
Returns
pyspark.sql.DataFrame DataFrame containing the test result for every featureagainst the label. If flatten is True, this DataFrame will contain one row per feature withthe following fields:
• featureIndex: int
• pValue: float
• degreesOfFreedom: int
• statistic: float
If flatten is False, this DataFrame will contain a single Row with the following fields:
• pValues: Vector
• degreesOfFreedom: Array[int]
• statistics: Vector
Each of these fields has one value per feature.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.stat import ChiSquareTest
>>> dataset = [[0, Vectors.dense([0, 0, 1])],
...            [0, Vectors.dense([1, 0, 1])],
...            [1, Vectors.dense([2, 1, 1])],
...            [1, Vectors.dense([3, 1, 1])]]
>>> dataset = spark.createDataFrame(dataset, ["label", "features"])
>>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label')
>>> chiSqResult.select("degreesOfFreedom").collect()[0]
Row(degreesOfFreedom=[3, 1, 0])
>>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label', True)
>>> row = chiSqResult.orderBy("featureIndex").collect()
>>> row[0].statistic
4.0
Correlation
class pyspark.ml.stat.Correlation
Compute the correlation matrix for the input dataset of Vectors using the specified method. Methods currently supported: pearson (default), spearman.
New in version 2.2.0.
Notes
For Spearman, a rank correlation, we need to create an RDD[Double] for each column and sort it in order to retrieve the ranks, and then join the columns back into an RDD[Vector], which is fairly costly. Cache the input Dataset before calling corr with method = 'spearman' to avoid recomputing the common lineage.
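A hedged sketch of that advice, using the dataset from the example below:

dataset.cache()      # persist once so the rank transformation reuses the data
spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').head()[0]
dataset.unpersist()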
Methods
corr(dataset, column[, method])  Compute the correlation matrix with specified method using dataset.
Methods Documentation
static corr(dataset, column, method='pearson')
Compute the correlation matrix with specified method using dataset.
New in version 2.2.0.
Parameters
dataset [pyspark.sql.DataFrame] A DataFrame.
column [str] The name of the column of vectors for which the correlation coefficient needsto be computed. This must be a column of the dataset, and it must contain Vector objects.
method [str, optional] String specifying the method to use for computing correlation. Sup-ported: pearson (default), spearman.
Returns
A DataFrame that contains the correlation matrix of the column of vectors. This DataFrame contains a single row and a single column of name METHODNAME(COLUMN).
Examples
>>> from pyspark.ml.linalg import DenseMatrix, Vectors
>>> from pyspark.ml.stat import Correlation
>>> dataset = [[Vectors.dense([1, 0, 0, -2])],
...            [Vectors.dense([4, 5, 0, 3])],
...            [Vectors.dense([6, 7, 0, 8])],
...            [Vectors.dense([9, 0, 0, 1])]]
>>> dataset = spark.createDataFrame(dataset, ['features'])
>>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0]
>>> print(str(pearsonCorr).replace('nan', 'NaN'))
DenseMatrix([[ 1.        ,  0.0556...,         NaN,  0.4004...],
             [ 0.0556...,  1.        ,         NaN,  0.9135...],
             [        NaN,         NaN,  1.        ,         NaN],
             [ 0.4004...,  0.9135...,         NaN,  1.        ]])
>>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0]
>>> print(str(spearmanCorr).replace('nan', 'NaN'))
DenseMatrix([[ 1.        ,  0.1054...,         NaN,  0.4       ],
             [ 0.1054...,  1.        ,         NaN,  0.9486...],
             [        NaN,         NaN,  1.        ,         NaN],
             [ 0.4       ,  0.9486...,         NaN,  1.        ]])
FValueTest
class pyspark.ml.stat.FValueTest
Conduct F Regression test for continuous features against continuous labels.
New in version 3.1.0.
Methods
test(dataset, featuresCol, labelCol[, flatten])  Perform an F Regression test using dataset.
Methods Documentation
static test(dataset, featuresCol, labelCol, flatten=False)
Perform an F Regression test using dataset.
New in version 3.1.0.
Parameters
dataset [pyspark.sql.DataFrame] DataFrame of continuous labels and continuous features.
featuresCol [str] Name of features column in dataset, of type Vector (VectorUDT).
labelCol [str] Name of label column in dataset, of any numerical type.
flatten [bool, optional] if True, flattens the returned dataframe.
Returns
pyspark.sql.DataFrame DataFrame containing the test result for every featureagainst the label. If flatten is True, this DataFrame will contain one row per feature withthe following fields:
• featureIndex: int
• pValue: float
• degreesOfFreedom: int
• fValue: float
If flatten is False, this DataFrame will contain a single Row with the following fields:
• pValues: Vector
• degreesOfFreedom: Array[int]
• fValues: Vector
Each of these fields has one value per feature.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.stat import FValueTest
>>> dataset = [[0.57495218, Vectors.dense([0.43486404, 0.57153633, 0.43175686,
...                                        0.51418671, 0.61632374, 0.96565515])],
...            [0.84619853, Vectors.dense([0.49162732, 0.6785187, 0.85460572,
...                                        0.59784822, 0.12394819, 0.53783355])],
...            [0.39777647, Vectors.dense([0.30879653, 0.54904515, 0.17103889,
...                                        0.40492506, 0.18957493, 0.5440016])],
...            [0.79201573, Vectors.dense([0.68114391, 0.60549825, 0.69094651,
...                                        0.62102109, 0.05471483, 0.96449167])]]
>>> dataset = spark.createDataFrame(dataset, ["label", "features"])
>>> fValueResult = FValueTest.test(dataset, 'features', 'label')
>>> row = fValueResult.select("fValues", "pValues").collect()
>>> row[0].fValues
DenseVector([3.741, 7.5807, 142.0684, 34.9849, 0.4112, 0.0539])
>>> row[0].pValues
DenseVector([0.1928, 0.1105, 0.007, 0.0274, 0.5871, 0.838])
>>> fValueResult = FValueTest.test(dataset, 'features', 'label', True)
>>> row = fValueResult.orderBy("featureIndex").collect()
>>> row[0].fValue
3.7409548308350593
KolmogorovSmirnovTest
class pyspark.ml.stat.KolmogorovSmirnovTest
Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a continuous distribution.
By comparing the largest difference between the empirical cumulative distribution of the sample data and the theoretical distribution, we can provide a test for the null hypothesis that the sample data comes from that theoretical distribution.
New in version 2.4.0.
Methods
test(dataset, sampleCol, distName, *params) Conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability distribution equality.
Methods Documentation
static test(dataset, sampleCol, distName, *params)
Conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability distribution equality. Currently supports the normal distribution, taking as parameters the mean and standard deviation.
New in version 2.4.0.
Parameters
dataset [pyspark.sql.DataFrame] a Dataset or a DataFrame containing the sample of data to test.

sampleCol [str] Name of sample column in dataset, of any numerical type.

distName [str] a string name for a theoretical distribution; currently only "norm" is supported.

params [float] a list of float values specifying the parameters to be used for the theoretical distribution. For the "norm" distribution, the parameters include mean and variance.
Returns
A DataFrame that contains the Kolmogorov-Smirnov test result for the input sampled data.
This DataFrame will contain a single Row with the following fields:
• pValue: Double
• statistic: Double
Examples
>>> from pyspark.ml.stat import KolmogorovSmirnovTest
>>> dataset = [[-1.0], [0.0], [1.0]]
>>> dataset = spark.createDataFrame(dataset, ['sample'])
>>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 0.0, 1.0).first()
>>> round(ksResult.pValue, 3)
1.0
>>> round(ksResult.statistic, 3)
0.175
>>> dataset = [[2.0], [3.0], [4.0]]
>>> dataset = spark.createDataFrame(dataset, ['sample'])
>>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 3.0, 1.0).first()
>>> round(ksResult.pValue, 3)
1.0
>>> round(ksResult.statistic, 3)
0.175
MultivariateGaussian
class pyspark.ml.stat.MultivariateGaussian(mean, cov)
Represents a (mean, cov) tuple.
New in version 3.0.0.
Examples
>>> from pyspark.ml.linalg import DenseMatrix, Vectors
>>> m = MultivariateGaussian(Vectors.dense([11, 12]), DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0)))
>>> (m.mean, m.cov.toArray())
(DenseVector([11.0, 12.0]), array([[ 1.,  5.],
       [ 3.,  2.]]))
Summarizer
class pyspark.ml.stat.Summarizer
Tools for vectorized statistics on MLlib Vectors. The methods in this package provide various statistics for Vectors contained inside DataFrames. This class lets users pick the statistics they would like to extract for a given column.
New in version 2.4.0.
Examples
>>> from pyspark.ml.stat import Summarizer
>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
>>> summarizer = Summarizer.metrics("mean", "count")
>>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
...                      Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
>>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False)
+-----------------------------------+
|aggregate_metrics(features, weight)|
+-----------------------------------+
|{[1.0,1.0,1.0], 1}                 |
+-----------------------------------+
>>> df.select(summarizer.summary(df.features)).show(truncate=False)
+--------------------------------+
|aggregate_metrics(features, 1.0)|
+--------------------------------+
|{[1.0,1.5,2.0], 2}              |
+--------------------------------+
>>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False)
+--------------+
|mean(features)|
+--------------+
|[1.0,1.0,1.0] |
+--------------+
>>> df.select(Summarizer.mean(df.features)).show(truncate=False)
+--------------+
|mean(features)|
+--------------+
|[1.0,1.5,2.0] |
+--------------+
Methods
count(col[, weightCol])  return a column of count summary
max(col[, weightCol])  return a column of max summary
mean(col[, weightCol])  return a column of mean summary
metrics(*metrics)  Given a list of metrics, provides a builder that in turn computes those metrics from a column.
min(col[, weightCol])  return a column of min summary
normL1(col[, weightCol])  return a column of normL1 summary
normL2(col[, weightCol])  return a column of normL2 summary
numNonZeros(col[, weightCol])  return a column of numNonZero summary
std(col[, weightCol])  return a column of std summary
sum(col[, weightCol])  return a column of sum summary
variance(col[, weightCol])  return a column of variance summary
Methods Documentation
static count(col, weightCol=None)
return a column of count summary

New in version 2.4.0.

static max(col, weightCol=None)
return a column of max summary

New in version 2.4.0.

static mean(col, weightCol=None)
return a column of mean summary

New in version 2.4.0.
static metrics(*metrics)
Given a list of metrics, provides a builder that in turn computes those metrics from a column.
See the documentation of Summarizer for an example.
The following metrics are accepted (case sensitive):
• mean: a vector that contains the coefficient-wise mean.
• sum: a vector that contains the coefficient-wise sum.
• variance: a vector that contains the coefficient-wise variance.

• std: a vector that contains the coefficient-wise standard deviation.
• count: the count of all vectors seen.
• numNonzeros: a vector with the number of non-zeros for each coefficient.
• max: the maximum for each coefficient.
• min: the minimum for each coefficient.
• normL2: the Euclidean norm for each coefficient.
• normL1: the L1 norm of each coefficient (sum of the absolute values).
New in version 2.4.0.
Returns
pyspark.ml.stat.SummaryBuilder
Notes
Currently, the performance of this interface is about 2x~3x slower than using the RDD interface.
Parameters

metrics [str] metrics that can be provided.
static min(col, weightCol=None)
return a column of min summary

New in version 2.4.0.

static normL1(col, weightCol=None)
return a column of normL1 summary

New in version 2.4.0.

static normL2(col, weightCol=None)
return a column of normL2 summary

New in version 2.4.0.

static numNonZeros(col, weightCol=None)
return a column of numNonZero summary

New in version 2.4.0.

static std(col, weightCol=None)
return a column of std summary

New in version 3.0.0.

static sum(col, weightCol=None)
return a column of sum summary

New in version 3.0.0.

static variance(col, weightCol=None)
return a column of variance summary

New in version 2.4.0.
SummaryBuilder
class pyspark.ml.stat.SummaryBuilder(jSummaryBuilder)
A builder object that provides summary statistics about a given column.

Users should not directly create such builders, but instead use one of the methods in pyspark.ml.stat.Summarizer.
New in version 2.4.0.
Methods
summary(featuresCol[, weightCol])  Returns an aggregate object that contains the summary of the column with the requested metrics.
Methods Documentation
summary(featuresCol, weightCol=None)
Returns an aggregate object that contains the summary of the column with the requested metrics.
New in version 2.4.0.
Parameters
featuresCol [str] a column that contains features Vector object.
weightCol [str, optional] a column that contains weight value. Default weight is 1.0.
Returns
pyspark.sql.Column an aggregate column that contains the statistics. The exact content of this structure is determined during the creation of the builder.
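The statistics come back as a single struct column, so individual metrics can be read out by field name. A minimal hedged sketch, reusing df and summarizer from the Summarizer example above:

row = df.select(summarizer.summary(df.features).alias("s")).first()
row.s.mean       # DenseVector of coefficient-wise means
row.s["count"]   # count of vectors; indexed access avoids Row's tuple.count method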
3.3.11 Tuning
ParamGridBuilder()  Builder for a param grid used in grid search-based model selection.
CrossValidator(*args, **kwargs)  K-fold cross validation performs model selection by splitting the dataset into a set of non-overlapping randomly partitioned folds which are used as separate training and test datasets, e.g., with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing.
CrossValidatorModel(bestModel[, avgMetrics, ...])  CrossValidatorModel contains the model with the highest average cross-validation metric across folds and uses this model to transform input data.
TrainValidationSplit(*args, **kwargs)  Validation for hyper-parameter tuning.
TrainValidationSplitModel(bestModel[, ...])  Model from train validation split.
ParamGridBuilder
class pyspark.ml.tuning.ParamGridBuilder
Builder for a param grid used in grid search-based model selection.
New in version 1.4.0.
Examples
>>> from pyspark.ml.classification import LogisticRegression
>>> lr = LogisticRegression()
>>> output = ParamGridBuilder() \
...     .baseOn({lr.labelCol: 'l'}) \
...     .baseOn([lr.predictionCol, 'p']) \
...     .addGrid(lr.regParam, [1.0, 2.0]) \
...     .addGrid(lr.maxIter, [1, 5]) \
...     .build()
>>> expected = [
...     {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
...     {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
...     {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
...     {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
>>> len(output) == len(expected)
True
>>> all([m in expected for m in output])
True
Methods
addGrid(param, values)  Sets the given parameters in this grid to fixed values.
baseOn(*args)  Sets the given parameters in this grid to fixed values.
build()  Builds and returns all combinations of parameters specified by the param grid.
Methods Documentation
addGrid(param, values)
Sets the given parameters in this grid to fixed values.

param must be an instance of Param associated with an instance of Params (such as Estimator or Transformer).
New in version 1.4.0.
baseOn(*args)
Sets the given parameters in this grid to fixed values. Accepts either a parameter dictionary or a list of (parameter, value) pairs.
New in version 1.4.0.
build()
Builds and returns all combinations of parameters specified by the param grid.
New in version 1.4.0.
CrossValidator
class pyspark.ml.tuning.CrossValidator(*args, **kwargs)
K-fold cross validation performs model selection by splitting the dataset into a set of non-overlapping randomly partitioned folds which are used as separate training and test datasets, e.g., with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. Each fold is used as the test set exactly once.
New in version 1.4.0.
Examples
>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.tuning import CrossValidatorModel
>>> import tempfile
>>> dataset = spark.createDataFrame(
...     [(Vectors.dense([0.0]), 0.0),
...      (Vectors.dense([0.4]), 1.0),
...      (Vectors.dense([0.5]), 0.0),
...      (Vectors.dense([0.6]), 1.0),
...      (Vectors.dense([1.0]), 1.0)] * 10,
...     ["features", "label"])
>>> lr = LogisticRegression()
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
...     parallelism=2)
>>> cvModel = cv.fit(dataset)
>>> cvModel.getNumFolds()
3
>>> cvModel.avgMetrics[0]
0.5
>>> path = tempfile.mkdtemp()
>>> model_path = path + "/model"
>>> cvModel.write().save(model_path)
>>> cvModelRead = CrossValidatorModel.read().load(model_path)
>>> cvModelRead.avgMetrics
[0.5, ...
>>> evaluator.evaluate(cvModel.transform(dataset))
0.8333...
>>> evaluator.evaluate(cvModelRead.transform(dataset))
0.8333...
Methods
clear(param)                     Clears a param from the param map if it has been explicitly set.
copy([extra])                    Creates a copy of this instance with a randomly generated uid and some extra params.
explainParam(param)              Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()                  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])         Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])           Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getCollectSubModels()            Gets the value of collectSubModels or its default value.
getEstimator()                   Gets the value of estimator or its default value.
getEstimatorParamMaps()          Gets the value of estimatorParamMaps or its default value.
getEvaluator()                   Gets the value of evaluator or its default value.
getFoldCol()                     Gets the value of foldCol or its default value.
getNumFolds()                    Gets the value of numFolds or its default value.
getOrDefault(param)              Gets the value of a param in the user-supplied param map or its default value.
getParallelism()                 Gets the value of parallelism or its default value.
getParam(paramName)              Gets a param by its name.
getSeed()                        Gets the value of seed or its default value.
hasDefault(param)                Checks whether a param has a default value.
hasParam(paramName)              Tests whether this instance contains a param with a given (string) name.
isDefined(param)                 Checks whether a param is explicitly set by user or has a default value.
isSet(param)                     Checks whether a param is explicitly set by user.
load(path)                       Reads an ML instance from the input path, a shortcut of read().load(path).
read()                           Returns an MLReader instance for this class.
save(path)                       Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)                Sets a parameter in the embedded param map.
setCollectSubModels(value)       Sets the value of collectSubModels.
setEstimator(value)              Sets the value of estimator.
setEstimatorParamMaps(value)     Sets the value of estimatorParamMaps.
setEvaluator(value)              Sets the value of evaluator.
setFoldCol(value)                Sets the value of foldCol.
setNumFolds(value)               Sets the value of numFolds.
setParallelism(value)            Sets the value of parallelism.
setParams(*args, **kwargs)       setParams(self, *, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, seed=None, parallelism=1, collectSubModels=False, foldCol=""): Sets params for cross validator.
setSeed(value)                   Sets the value of seed.
write()                          Returns an MLWriter instance for this ML instance.
Attributes
collectSubModels
estimator
estimatorParamMaps
evaluator
foldCol
numFolds
parallelism
params               Returns all params ordered by name.
seed
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with a randomly generated uid and some extra params. This creates a deep copy of the embedded paramMap and copies the embedded and extra parameters over.
New in version 1.4.0.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
CrossValidator Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
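A small sketch of the documented precedence (default < user-supplied < extra), using an illustrative LogisticRegression instance:

    from pyspark.ml.classification import LogisticRegression

    lr = LogisticRegression(maxIter=5)           # user-supplied value overrides the default
    pm = lr.extractParamMap({lr.maxIter: 10})    # extra value overrides the user-supplied one
    assert pm[lr.maxIter] == 10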
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
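A sketch of consuming the returned iterator, assuming the dataset, lr, and grid from the class example above; because index values may arrive out of order, models are placed by index:

    models = [None] * len(grid)
    for index, model in lr.fitMultiple(dataset, grid):
        models[index] = model   # each item is an (index, model) pair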
getCollectSubModels()
    Gets the value of collectSubModels or its default value.

getEstimator()
    Gets the value of estimator or its default value.

    New in version 2.0.0.

getEstimatorParamMaps()
    Gets the value of estimatorParamMaps or its default value.

    New in version 2.0.0.

getEvaluator()
    Gets the value of evaluator or its default value.

    New in version 2.0.0.

getFoldCol()
    Gets the value of foldCol or its default value.

    New in version 3.1.0.

getNumFolds()
    Gets the value of numFolds or its default value.

    New in version 1.4.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParallelism()
    Gets the value of parallelism or its default value.

getParam(paramName)
    Gets a param by its name.

getSeed()
    Gets the value of seed or its default value.
hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

    New in version 2.3.0.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setCollectSubModels(value)
    Sets the value of collectSubModels.

setEstimator(value)
    Sets the value of estimator.

    New in version 2.0.0.

setEstimatorParamMaps(value)
    Sets the value of estimatorParamMaps.

    New in version 2.0.0.

setEvaluator(value)
    Sets the value of evaluator.

    New in version 2.0.0.

setFoldCol(value)
    Sets the value of foldCol.

    New in version 3.1.0.

setNumFolds(value)
    Sets the value of numFolds.

    New in version 1.4.0.

setParallelism(value)
    Sets the value of parallelism.

setParams(*args, **kwargs)
    setParams(self, *, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, seed=None, parallelism=1, collectSubModels=False, foldCol=""): Sets params for cross validator.

    New in version 1.4.0.

setSeed(value)
    Sets the value of seed.
write()
    Returns an MLWriter instance for this ML instance.
New in version 2.3.0.
Attributes Documentation
collectSubModels = Param(parent='undefined', name='collectSubModels', doc='Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.')
estimator = Param(parent='undefined', name='estimator', doc='estimator to be cross-validated')
estimatorParamMaps = Param(parent='undefined', name='estimatorParamMaps', doc='estimator param maps')
evaluator = Param(parent='undefined', name='evaluator', doc='evaluator used to select hyper-parameters that maximize the validator metric')
foldCol = Param(parent='undefined', name='foldCol', doc="Param for the column name of user specified fold number. Once this is specified, :py:class:`CrossValidator` won't do random k-fold split. Note that this column should be integer type with range [0, numFolds) and Spark will throw exception on out-of-range fold numbers.")
numFolds = Param(parent='undefined', name='numFolds', doc='number of folds for cross validation')
parallelism = Param(parent='undefined', name='parallelism', doc='the number of threads to use when running parallel algorithms (>= 1).')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
seed = Param(parent='undefined', name='seed', doc='random seed.')
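The foldCol param above replaces the random k-fold split with user-supplied fold assignments. A hedged sketch, assuming a DataFrame df that carries a stable id column alongside the features and label columns, and the lr, grid, and evaluator from the example above:

    from pyspark.sql import functions as F

    # Derive a deterministic integer fold index in [0, numFolds) from the id column.
    df_with_folds = df.withColumn("fold", F.abs(F.hash("id")) % 3)

    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
                        numFolds=3, foldCol="fold")
    cvModel = cv.fit(df_with_folds)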
CrossValidatorModel
class pyspark.ml.tuning.CrossValidatorModel(bestModel, avgMetrics=None, subModels=None)
    CrossValidatorModel contains the model with the highest average cross-validation metric across folds and uses this model to transform input data. CrossValidatorModel also tracks the metrics for each param map evaluated.
New in version 1.4.0.
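A sketch of inspecting a fitted model, reusing cvModel, grid, and dataset from the CrossValidator example above:

    # avgMetrics[i] is the metric for grid[i], averaged over the k folds.
    # With the default BinaryClassificationEvaluator, larger is better.
    best_index = max(range(len(grid)), key=lambda i: cvModel.avgMetrics[i])
    print(grid[best_index])                   # the winning param map
    predictions = cvModel.transform(dataset)  # delegates to bestModel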
Methods
clear(param)                   Clears a param from the param map if it has been explicitly set.
copy([extra])                  Creates a copy of this instance with a randomly generated uid and some extra params.
explainParam(param)            Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()                Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])       Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getEstimator()                 Gets the value of estimator or its default value.
getEstimatorParamMaps()        Gets the value of estimatorParamMaps or its default value.
getEvaluator()                 Gets the value of evaluator or its default value.
getFoldCol()                   Gets the value of foldCol or its default value.
getNumFolds()                  Gets the value of numFolds or its default value.
getOrDefault(param)            Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)            Gets a param by its name.
getSeed()                      Gets the value of seed or its default value.
hasDefault(param)              Checks whether a param has a default value.
hasParam(paramName)            Tests whether this instance contains a param with a given (string) name.
isDefined(param)               Checks whether a param is explicitly set by user or has a default value.
isSet(param)                   Checks whether a param is explicitly set by user.
load(path)                     Reads an ML instance from the input path, a shortcut of read().load(path).
read()                         Returns an MLReader instance for this class.
save(path)                     Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)              Sets a parameter in the embedded param map.
transform(dataset[, params])   Transforms the input dataset with optional parameters.
write()                        Returns an MLWriter instance for this ML instance.
Attributes
estimator
estimatorParamMaps
evaluator
foldCol
numFolds
params               Returns all params ordered by name.
seed
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with a randomly generated uid and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. It does not copy the extra Params into the subModels.
New in version 1.4.0.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
CrossValidatorModel Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getEstimator()
    Gets the value of estimator or its default value.

    New in version 2.0.0.

getEstimatorParamMaps()
    Gets the value of estimatorParamMaps or its default value.

    New in version 2.0.0.

getEvaluator()
    Gets the value of evaluator or its default value.

    New in version 2.0.0.

getFoldCol()
    Gets the value of foldCol or its default value.

    New in version 3.1.0.

getNumFolds()
    Gets the value of numFolds or its default value.

    New in version 1.4.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getSeed()
    Gets the value of seed or its default value.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.
isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

    New in version 2.3.0.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
    Returns an MLWriter instance for this ML instance.
New in version 2.3.0.
Attributes Documentation
estimator = Param(parent='undefined', name='estimator', doc='estimator to be cross-validated')
estimatorParamMaps = Param(parent='undefined', name='estimatorParamMaps', doc='estimator param maps')
evaluator = Param(parent='undefined', name='evaluator', doc='evaluator used to select hyper-parameters that maximize the validator metric')
foldCol = Param(parent='undefined', name='foldCol', doc="Param for the column name of user specified fold number. Once this is specified, :py:class:`CrossValidator` won't do random k-fold split. Note that this column should be integer type with range [0, numFolds) and Spark will throw exception on out-of-range fold numbers.")
numFolds = Param(parent='undefined', name='numFolds', doc='number of folds for cross validation')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
seed = Param(parent='undefined', name='seed', doc='random seed.')
TrainValidationSplit
class pyspark.ml.tuning.TrainValidationSplit(*args, **kwargs)
    Validation for hyper-parameter tuning. Randomly splits the input dataset into train and validation sets, and uses the evaluation metric on the validation set to select the best model. Similar to CrossValidator, but only splits the set once.
New in version 2.0.0.
Examples
>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.tuning import TrainValidationSplitModel
>>> import tempfile
>>> dataset = spark.createDataFrame(
...     [(Vectors.dense([0.0]), 0.0),
...      (Vectors.dense([0.4]), 1.0),
...      (Vectors.dense([0.5]), 0.0),
...      (Vectors.dense([0.6]), 1.0),
...      (Vectors.dense([1.0]), 1.0)] * 10,
...     ["features", "label"]).repartition(1)
>>> lr = LogisticRegression()
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
...     parallelism=1, seed=42)
>>> tvsModel = tvs.fit(dataset)
>>> tvsModel.getTrainRatio()
0.75
>>> tvsModel.validationMetrics
[0.5, ...
>>> path = tempfile.mkdtemp()
>>> model_path = path + "/model"
>>> tvsModel.write().save(model_path)
>>> tvsModelRead = TrainValidationSplitModel.read().load(model_path)
>>> tvsModelRead.validationMetrics
[0.5, ...
>>> evaluator.evaluate(tvsModel.transform(dataset))
0.833...
>>> evaluator.evaluate(tvsModelRead.transform(dataset))
0.833...
Methods
clear(param)                     Clears a param from the param map if it has been explicitly set.
copy([extra])                    Creates a copy of this instance with a randomly generated uid and some extra params.
explainParam(param)              Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()                  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])         Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])           Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getCollectSubModels()            Gets the value of collectSubModels or its default value.
getEstimator()                   Gets the value of estimator or its default value.
getEstimatorParamMaps()          Gets the value of estimatorParamMaps or its default value.
getEvaluator()                   Gets the value of evaluator or its default value.
getOrDefault(param)              Gets the value of a param in the user-supplied param map or its default value.
getParallelism()                 Gets the value of parallelism or its default value.
getParam(paramName)              Gets a param by its name.
getSeed()                        Gets the value of seed or its default value.
getTrainRatio()                  Gets the value of trainRatio or its default value.
hasDefault(param)                Checks whether a param has a default value.
hasParam(paramName)              Tests whether this instance contains a param with a given (string) name.
isDefined(param)                 Checks whether a param is explicitly set by user or has a default value.
isSet(param)                     Checks whether a param is explicitly set by user.
load(path)                       Reads an ML instance from the input path, a shortcut of read().load(path).
read()                           Returns an MLReader instance for this class.
save(path)                       Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)                Sets a parameter in the embedded param map.
setCollectSubModels(value)       Sets the value of collectSubModels.
setEstimator(value)              Sets the value of estimator.
setEstimatorParamMaps(value)     Sets the value of estimatorParamMaps.
setEvaluator(value)              Sets the value of evaluator.
setParallelism(value)            Sets the value of parallelism.
setParams(*args, **kwargs)       setParams(self, *, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None): Sets params for the train validation split.
setSeed(value)                   Sets the value of seed.
setTrainRatio(value)             Sets the value of trainRatio.
write()                          Returns an MLWriter instance for this ML instance.
Attributes
collectSubModels
estimator
estimatorParamMaps
evaluator
parallelism
params               Returns all params ordered by name.
seed
trainRatio
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with a randomly generated uid and some extra params. This creates a deep copy of the embedded paramMap and copies the embedded and extra parameters over.
New in version 2.0.0.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
TrainValidationSplit Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
Returns
Transformer or a list of Transformer fitted model(s)
fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset.
paramMaps [collections.abc.Sequence] A Sequence of param maps.
Returns
_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
getCollectSubModels()
    Gets the value of collectSubModels or its default value.

getEstimator()
    Gets the value of estimator or its default value.

    New in version 2.0.0.

getEstimatorParamMaps()
    Gets the value of estimatorParamMaps or its default value.

    New in version 2.0.0.

getEvaluator()
    Gets the value of evaluator or its default value.

    New in version 2.0.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParallelism()
    Gets the value of parallelism or its default value.

getParam(paramName)
    Gets a param by its name.

getSeed()
    Gets the value of seed or its default value.

getTrainRatio()
    Gets the value of trainRatio or its default value.

    New in version 2.0.0.

hasDefault(param)
    Checks whether a param has a default value.
hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

    New in version 2.3.0.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setCollectSubModels(value)
    Sets the value of collectSubModels.

setEstimator(value)
    Sets the value of estimator.

    New in version 2.0.0.

setEstimatorParamMaps(value)
    Sets the value of estimatorParamMaps.

    New in version 2.0.0.

setEvaluator(value)
    Sets the value of evaluator.

    New in version 2.0.0.

setParallelism(value)
    Sets the value of parallelism.

setParams(*args, **kwargs)
    setParams(self, *, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None): Sets params for the train validation split.

    New in version 2.0.0.

setSeed(value)
    Sets the value of seed.

setTrainRatio(value)
    Sets the value of trainRatio.

    New in version 2.0.0.

write()
    Returns an MLWriter instance for this ML instance.
New in version 2.3.0.
Attributes Documentation
collectSubModels = Param(parent='undefined', name='collectSubModels', doc='Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.')
estimator = Param(parent='undefined', name='estimator', doc='estimator to be cross-validated')
estimatorParamMaps = Param(parent='undefined', name='estimatorParamMaps', doc='estimator param maps')
evaluator = Param(parent='undefined', name='evaluator', doc='evaluator used to select hyper-parameters that maximize the validator metric')
parallelism = Param(parent='undefined', name='parallelism', doc='the number of threads to use when running parallel algorithms (>= 1).')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
seed = Param(parent='undefined', name='seed', doc='random seed.')
trainRatio = Param(parent='undefined', name='trainRatio', doc='Param for ratio between train and validation data. Must be between 0 and 1.')
TrainValidationSplitModel
class pyspark.ml.tuning.TrainValidationSplitModel(bestModel, validationMetrics=None, subModels=None)
    Model from train validation split.
New in version 2.0.0.
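A sketch of inspecting a fitted model, reusing tvsModel and grid from the TrainValidationSplit example above:

    # validationMetrics[i] is the validation-set metric for grid[i].
    for params, metric in zip(grid, tvsModel.validationMetrics):
        print(params, metric)
    print(tvsModel.bestModel)  # the model selected by the validation metric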
Methods
clear(param)                   Clears a param from the param map if it has been explicitly set.
copy([extra])                  Creates a copy of this instance with a randomly generated uid and some extra params.
explainParam(param)            Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()                Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])       Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getEstimator()                 Gets the value of estimator or its default value.
getEstimatorParamMaps()        Gets the value of estimatorParamMaps or its default value.
getEvaluator()                 Gets the value of evaluator or its default value.
getOrDefault(param)            Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)            Gets a param by its name.
getSeed()                      Gets the value of seed or its default value.
getTrainRatio()                Gets the value of trainRatio or its default value.
hasDefault(param)              Checks whether a param has a default value.
hasParam(paramName)            Tests whether this instance contains a param with a given (string) name.
isDefined(param)               Checks whether a param is explicitly set by user or has a default value.
isSet(param)                   Checks whether a param is explicitly set by user.
load(path)                     Reads an ML instance from the input path, a shortcut of read().load(path).
read()                         Returns an MLReader instance for this class.
save(path)                     Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)              Sets a parameter in the embedded param map.
transform(dataset[, params])   Transforms the input dataset with optional parameters.
write()                        Returns an MLWriter instance for this ML instance.
Attributes
estimator
estimatorParamMaps
evaluator
params               Returns all params ordered by name.
seed
trainRatio
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with a randomly generated uid and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. It also creates a shallow copy of the validationMetrics. It does not copy the extra Params into the subModels.
New in version 2.0.0.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
TrainValidationSplitModel Copy of this instance
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getEstimator()
    Gets the value of estimator or its default value.

    New in version 2.0.0.

getEstimatorParamMaps()
    Gets the value of estimatorParamMaps or its default value.

    New in version 2.0.0.

getEvaluator()
    Gets the value of evaluator or its default value.

    New in version 2.0.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getSeed()
    Gets the value of seed or its default value.

getTrainRatio()
    Gets the value of trainRatio or its default value.

    New in version 2.0.0.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

    New in version 2.3.0.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.
transform(dataset, params=None)
    Transforms the input dataset with optional parameters.
New in version 1.3.0.
Parameters
dataset [pyspark.sql.DataFrame] input dataset
params [dict, optional] an optional param map that overrides embedded params.
Returns
pyspark.sql.DataFrame transformed dataset
write()
    Returns an MLWriter instance for this ML instance.
New in version 2.3.0.
Attributes Documentation
estimator = Param(parent='undefined', name='estimator', doc='estimator to be cross-validated')
estimatorParamMaps = Param(parent='undefined', name='estimatorParamMaps', doc='estimator param maps')
evaluator = Param(parent='undefined', name='evaluator', doc='evaluator used to select hyper-parameters that maximize the validator metric')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
seed = Param(parent='undefined', name='seed', doc='random seed.')
trainRatio = Param(parent='undefined', name='trainRatio', doc='Param for ratio between train and validation data. Must be between 0 and 1.')
3.3.12 Evaluation
Evaluator()                                      Base class for evaluators that compute metrics from predictions.
BinaryClassificationEvaluator(*args, **kwargs)   Evaluator for binary classification, which expects input columns rawPrediction, label and an optional weight column.
RegressionEvaluator(*args, **kwargs)             Evaluator for Regression, which expects input columns prediction, label and an optional weight column.
MulticlassClassificationEvaluator(*args, ...)    Evaluator for Multiclass Classification, which expects input columns: prediction, label, weight (optional) and probabilityCol (only for logLoss).
MultilabelClassificationEvaluator(*args, ...)    Evaluator for Multilabel Classification, which expects two input columns: prediction and label.
ClusteringEvaluator(*args, **kwargs)             Evaluator for Clustering results, which expects two input columns: prediction and features.
RankingEvaluator(*args, **kwargs)                Evaluator for Ranking, which expects two input columns: prediction and label.
Evaluator
class pyspark.ml.evaluation.Evaluator
    Base class for evaluators that compute metrics from predictions.
New in version 1.4.0.
Methods
clear(param)                 Clears a param from the param map if it has been explicitly set.
copy([extra])                Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset[, params])  Evaluates the output with optional parameters.
explainParam(param)          Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()              Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])     Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getOrDefault(param)          Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)          Gets a param by its name.
hasDefault(param)            Checks whether a param has a default value.
hasParam(paramName)          Tests whether this instance contains a param with a given (string) name.
isDefined(param)             Checks whether a param is explicitly set by user or has a default value.
isLargerBetter()             Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False).
isSet(param)                 Checks whether a param is explicitly set by user.
set(param, value)            Sets a parameter in the embedded param map.
Attributes
params Returns all params ordered by name.
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
Params Copy of this instance
evaluate(dataset, params=None)
    Evaluates the output with optional parameters.
New in version 1.4.0.
Parameters
dataset [pyspark.sql.DataFrame] a dataset that contains labels/observations andpredictions
params [dict, optional] an optional param map that overrides embedded params
Returns
float metric
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.
hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isLargerBetter()
    Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False). A given evaluator may support multiple metrics which may be maximized or minimized.
New in version 1.5.0.
isSet(param)
    Checks whether a param is explicitly set by user.

set(param, value)
    Sets a parameter in the embedded param map.
Attributes Documentation
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
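Evaluator is abstract: a pure-Python subclass implements _evaluate() (which evaluate() calls after resolving params) and, where the metric should be minimized, isLargerBetter(). A hedged sketch with illustrative column names:

    from pyspark.ml.evaluation import Evaluator
    from pyspark.sql import functions as F

    class MeanSquaredErrorEvaluator(Evaluator):
        """Toy evaluator returning the mean squared error of a prediction column."""

        def _evaluate(self, dataset):
            err = F.col("prediction") - F.col("label")
            return dataset.select(F.avg(err * err)).first()[0]

        def isLargerBetter(self):
            return False  # lower error is better, so model selection minimizes it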
BinaryClassificationEvaluator
class pyspark.ml.evaluation.BinaryClassificationEvaluator(*args, **kwargs)
    Evaluator for binary classification, which expects input columns rawPrediction, label and an optional weight column. The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw predictions, scores, or label probabilities).
New in version 1.4.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> scoreAndLabels = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]),
...     [(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)])
>>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"])
...
>>> evaluator = BinaryClassificationEvaluator()
>>> evaluator.setRawPredictionCol("raw")
BinaryClassificationEvaluator...
>>> evaluator.evaluate(dataset)
0.70...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
0.83...
>>> bce_path = temp_path + "/bce"
>>> evaluator.save(bce_path)
>>> evaluator2 = BinaryClassificationEvaluator.load(bce_path)
>>> str(evaluator2.getRawPredictionCol())
'raw'
>>> scoreAndLabelsAndWeight = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1], x[2]),
...     [(0.1, 0.0, 1.0), (0.1, 1.0, 0.9), (0.4, 0.0, 0.7), (0.6, 0.0, 0.9),
...      (0.6, 1.0, 1.0), (0.6, 1.0, 0.3), (0.8, 1.0, 1.0)])
>>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["raw", "label", "weight"])
...
>>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw", weightCol="weight")
>>> evaluator.evaluate(dataset)
0.70...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
0.82...
>>> evaluator.getNumBins()
1000
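The example above uses vector-typed raw predictions; the column may instead be a plain double giving the probability of label 1. A minimal sketch of that form, assuming a SparkSession named spark:

    df = spark.createDataFrame(
        [(0.1, 0.0), (0.9, 1.0), (0.4, 0.0), (0.8, 1.0)],
        ["raw", "label"])

    evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw")
    auc = evaluator.evaluate(df)  # areaUnderROC by default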
Methods
clear(param)                 Clears a param from the param map if it has been explicitly set.
copy([extra])                Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset[, params])  Evaluates the output with optional parameters.
explainParam(param)          Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()              Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])     Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getLabelCol()                Gets the value of labelCol or its default value.
getMetricName()              Gets the value of metricName or its default value.
getNumBins()                 Gets the value of numBins or its default value.
getOrDefault(param)          Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)          Gets a param by its name.
getRawPredictionCol()        Gets the value of rawPredictionCol or its default value.
getWeightCol()               Gets the value of weightCol or its default value.
hasDefault(param)            Checks whether a param has a default value.
hasParam(paramName)          Tests whether this instance contains a param with a given (string) name.
isDefined(param)             Checks whether a param is explicitly set by user or has a default value.
isLargerBetter()             Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False).
isSet(param)                 Checks whether a param is explicitly set by user.
load(path)                   Reads an ML instance from the input path, a shortcut of read().load(path).
read()                       Returns an MLReader instance for this class.
save(path)                   Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)            Sets a parameter in the embedded param map.
setLabelCol(value)           Sets the value of labelCol.
setMetricName(value)         Sets the value of metricName.
setNumBins(value)            Sets the value of numBins.
setParams(self, \*[, rawPredictionCol, ...])  Sets params for binary classification evaluator.
setRawPredictionCol(value)   Sets the value of rawPredictionCol.
setWeightCol(value)          Sets the value of weightCol.
write()                      Returns an MLWriter instance for this ML instance.
Attributes
labelCol
metricName
numBins
params               Returns all params ordered by name.
rawPredictionCol
weightCol
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
evaluate(dataset, params=None)
    Evaluates the output with optional parameters.
New in version 1.4.0.
Parameters
dataset [pyspark.sql.DataFrame] a dataset that contains labels/observations andpredictions
params [dict, optional] an optional param map that overrides embedded params
Returns
float metric
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getLabelCol()
    Gets the value of labelCol or its default value.

getMetricName()
    Gets the value of metricName or its default value.

    New in version 1.4.0.

getNumBins()
    Gets the value of numBins or its default value.

    New in version 3.0.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getRawPredictionCol()
    Gets the value of rawPredictionCol or its default value.

getWeightCol()
    Gets the value of weightCol or its default value.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isLargerBetter()
    Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False). A given evaluator may support multiple metrics which may be maximized or minimized.

    New in version 1.5.0.

isSet(param)
    Checks whether a param is explicitly set by user.
classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setLabelCol(value)
    Sets the value of labelCol.

setMetricName(value)
    Sets the value of metricName.

    New in version 1.4.0.

setNumBins(value)
    Sets the value of numBins.

    New in version 3.0.0.

setParams(self, *, rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC", weightCol=None, numBins=1000)
    Sets params for binary classification evaluator.

    New in version 1.4.0.

setRawPredictionCol(value)
    Sets the value of rawPredictionCol.

setWeightCol(value)
    Sets the value of weightCol.

    New in version 3.0.0.

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
metricName = Param(parent='undefined', name='metricName', doc='metric name in evaluation (areaUnderROC|areaUnderPR)')
numBins = Param(parent='undefined', name='numBins', doc='Number of bins to down-sample the curves (ROC curve, PR curve) in area computation. If 0, no down-sampling will occur. Must be >= 0.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
RegressionEvaluator
class pyspark.ml.evaluation.RegressionEvaluator(*args, **kwargs)
    Evaluator for Regression, which expects input columns prediction, label and an optional weight column.
New in version 1.4.0.
Examples
>>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5),
...     (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)]
>>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"])
...
>>> evaluator = RegressionEvaluator()
>>> evaluator.setPredictionCol("raw")
RegressionEvaluator...
>>> evaluator.evaluate(dataset)
2.842...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
0.993...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
2.649...
>>> re_path = temp_path + "/re"
>>> evaluator.save(re_path)
>>> evaluator2 = RegressionEvaluator.load(re_path)
>>> str(evaluator2.getPredictionCol())
'raw'
>>> scoreAndLabelsAndWeight = [(-28.98343821, -27.0, 1.0), (20.21491975, 21.5, 0.8),
...     (-25.98418959, -22.0, 1.0), (30.69731842, 33.0, 0.6), (74.69283752, 71.0, 0.2)]
>>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["raw", "label", "weight"])
...
>>> evaluator = RegressionEvaluator(predictionCol="raw", weightCol="weight")
>>> evaluator.evaluate(dataset)
2.740...
>>> evaluator.getThroughOrigin()
False
Methods
clear(param)                 Clears a param from the param map if it has been explicitly set.
copy([extra])                Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset[, params])  Evaluates the output with optional parameters.
explainParam(param)          Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()              Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])     Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getLabelCol()                Gets the value of labelCol or its default value.
getMetricName()              Gets the value of metricName or its default value.
getOrDefault(param)          Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)          Gets a param by its name.
getPredictionCol()           Gets the value of predictionCol or its default value.
getThroughOrigin()           Gets the value of throughOrigin or its default value.
getWeightCol()               Gets the value of weightCol or its default value.
hasDefault(param)            Checks whether a param has a default value.
hasParam(paramName)          Tests whether this instance contains a param with a given (string) name.
isDefined(param)             Checks whether a param is explicitly set by user or has a default value.
isLargerBetter()             Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False).
isSet(param)                 Checks whether a param is explicitly set by user.
load(path)                   Reads an ML instance from the input path, a shortcut of read().load(path).
read()                       Returns an MLReader instance for this class.
save(path)                   Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)            Sets a parameter in the embedded param map.
setLabelCol(value)           Sets the value of labelCol.
setMetricName(value)         Sets the value of metricName.
setParams(self, \*[, predictionCol, ...])  Sets params for regression evaluator.
setPredictionCol(value)      Sets the value of predictionCol.
setThroughOrigin(value)      Sets the value of throughOrigin.
setWeightCol(value)          Sets the value of weightCol.
write()                      Returns an MLWriter instance for this ML instance.
Attributes
labelCol
metricName
params               Returns all params ordered by name.
predictionCol
throughOrigin
weightCol
Methods Documentation
clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
evaluate(dataset, params=None)
    Evaluates the output with optional parameters.
New in version 1.4.0.
Parameters
dataset [pyspark.sql.DataFrame] a dataset that contains labels/observations andpredictions
params [dict, optional] an optional param map that overrides embedded params
Returns
float metric
explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getLabelCol()
    Gets the value of labelCol or its default value.

getMetricName()
    Gets the value of metricName or its default value.

    New in version 1.4.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getParam(paramName)
    Gets a param by its name.

getPredictionCol()
    Gets the value of predictionCol or its default value.

getThroughOrigin()
    Gets the value of throughOrigin or its default value.

    New in version 3.0.0.

getWeightCol()
    Gets the value of weightCol or its default value.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isLargerBetter()
    Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False). A given evaluator may support multiple metrics which may be maximized or minimized.

    New in version 1.5.0.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setLabelCol(value)
    Sets the value of labelCol.

setMetricName(value)
    Sets the value of metricName.

    New in version 1.4.0.

setParams(self, *, predictionCol="prediction", labelCol="label", metricName="rmse", weightCol=None, throughOrigin=False)
    Sets params for regression evaluator.

    New in version 1.4.0.

setPredictionCol(value)
    Sets the value of predictionCol.

setThroughOrigin(value)
    Sets the value of throughOrigin.

    New in version 3.0.0.

setWeightCol(value)
    Sets the value of weightCol.

    New in version 3.0.0.

write()
    Returns an MLWriter instance for this ML instance.
Attributes Documentation
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
metricName = Param(parent='undefined', name='metricName', doc='metric name in evaluation - one of:\n rmse - root mean squared error (default)\n mse - mean squared error\n r2 - r^2 metric\n mae - mean absolute error\n var - explained variance.')
params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
throughOrigin = Param(parent='undefined', name='throughOrigin', doc='whether the regression is through the origin.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
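Because metricName is an ordinary param, a single evaluator can report every supported regression metric through per-call overrides; a sketch reusing the evaluator and dataset from the example above:

    for metric in ["rmse", "mse", "r2", "mae", "var"]:
        print(metric, evaluator.evaluate(dataset, {evaluator.metricName: metric}))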
MulticlassClassificationEvaluator
class pyspark.ml.evaluation.MulticlassClassificationEvaluator(*args, **kwargs)
    Evaluator for Multiclass Classification, which expects input columns: prediction, label, weight (optional) and probabilityCol (only for logLoss).
New in version 1.5.0.
Examples
>>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]
>>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
>>> evaluator = MulticlassClassificationEvaluator()
>>> evaluator.setPredictionCol("prediction")
MulticlassClassificationEvaluator...
>>> evaluator.evaluate(dataset)
0.66...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
0.66...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "truePositiveRateByLabel",
...     evaluator.metricLabel: 1.0})
0.75...
>>> evaluator.setMetricName("hammingLoss")
MulticlassClassificationEvaluator...
>>> evaluator.evaluate(dataset)
0.33...
>>> mce_path = temp_path + "/mce"
>>> evaluator.save(mce_path)
>>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path)
>>> str(evaluator2.getPredictionCol())
'prediction'
>>> scoreAndLabelsAndWeight = [(0.0, 0.0, 1.0), (0.0, 1.0, 1.0), (0.0, 0.0, 1.0),
...     (1.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0),
...     (2.0, 2.0, 1.0), (2.0, 0.0, 1.0)]
>>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["prediction", "label", "weight"])
>>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction",
...     weightCol="weight")
>>> evaluator.evaluate(dataset)
0.66...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
0.66...
>>> predictionAndLabelsWithProbabilities = [
...     (1.0, 1.0, 1.0, [0.1, 0.8, 0.1]), (0.0, 2.0, 1.0, [0.9, 0.05, 0.05]),
...     (0.0, 0.0, 1.0, [0.8, 0.2, 0.0]), (1.0, 1.0, 1.0, [0.3, 0.65, 0.05])]
>>> dataset = spark.createDataFrame(predictionAndLabelsWithProbabilities, ["prediction",
...     "label", "weight", "probability"])
>>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction",
...     probabilityCol="probability")
>>> evaluator.setMetricName("logLoss")
MulticlassClassificationEvaluator...
>>> evaluator.evaluate(dataset)
0.9682...
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset[, params])  Evaluates the output with optional parameters.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBeta()  Gets the value of beta or its default value.
getEps()  Gets the value of eps or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getMetricLabel()  Gets the value of metricLabel or its default value.
getMetricName()  Gets the value of metricName or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isLargerBetter()  Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False).
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setBeta(value)  Sets the value of beta.
setEps(value)  Sets the value of eps.
setLabelCol(value)  Sets the value of labelCol.
setMetricLabel(value)  Sets the value of metricLabel.
setMetricName(value)  Sets the value of metricName.
setParams(self, *[, predictionCol, ...])  Sets params for multiclass classification evaluator.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setWeightCol(value)  Sets the value of weightCol.
write()  Returns an MLWriter instance for this ML instance.
Attributes
beta
eps
labelCol
metricLabel
metricName
params  Returns all params ordered by name.
predictionCol
probabilityCol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
evaluate(dataset, params=None)
Evaluates the output with optional parameters.
New in version 1.4.0.
Parameters
dataset [pyspark.sql.DataFrame] a dataset that contains labels/observations and predictions
params [dict, optional] an optional param map that overrides embedded params
Returns
float metric
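A short sketch of the params argument (reusing the evaluator and dataset names from the Examples above, so it is illustrative rather than standalone): the override applies to that single call only and leaves the evaluator's embedded params untouched.

# One-off metric override; the evaluator's own metricName is not mutated.
accuracy = evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
default_score = evaluator.evaluate(dataset)  # uses the embedded metricName again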
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
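A small sketch of the merge ordering described above (assuming only pyspark itself): a user-supplied value beats the default, and the extra map beats both.

from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator()   # default metricName is "f1"
evaluator.setMetricName("accuracy")               # user-supplied value overrides the default
merged = evaluator.extractParamMap({evaluator.metricName: "logLoss"})
print(merged[evaluator.metricName])               # extra value wins: 'logLoss'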
getBeta()
Gets the value of beta or its default value.

New in version 3.0.0.

getEps()
Gets the value of eps or its default value.

New in version 3.0.0.

getLabelCol()
Gets the value of labelCol or its default value.

getMetricLabel()
Gets the value of metricLabel or its default value.

New in version 3.0.0.

getMetricName()
Gets the value of metricName or its default value.

New in version 1.5.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isLargerBetter()
Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False). A given evaluator may support multiple metrics which may be maximized or minimized.
New in version 1.5.0.
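A sketch of why this flag exists; pick_best is a hypothetical helper (not a pyspark API) of the kind that model-selection code such as cross-validation uses internally, flipping its comparison to match the metric's direction.

def pick_best(scored_models, evaluator):
    # scored_models: iterable of (model, metric_value) pairs computed elsewhere.
    choose = max if evaluator.isLargerBetter() else min
    return choose(scored_models, key=lambda pair: pair[1])[0]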
isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setBeta(value)
Sets the value of beta.

New in version 3.0.0.

setEps(value)
Sets the value of eps.

New in version 3.0.0.

setLabelCol(value)
Sets the value of labelCol.

setMetricLabel(value)
Sets the value of metricLabel.
New in version 3.0.0.
setMetricName(value)
Sets the value of metricName.
New in version 1.5.0.
setParams(self, *, predictionCol="prediction", labelCol="label", metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0, probabilityCol="probability", eps=1e-15)
Sets params for multiclass classification evaluator.
New in version 1.5.0.
setPredictionCol(value)
Sets the value of predictionCol.

setProbabilityCol(value)
Sets the value of probabilityCol.
New in version 3.0.0.
setWeightCol(value)
Sets the value of weightCol.
New in version 3.0.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
beta = Param(parent='undefined', name='beta', doc='The beta value used in weightedFMeasure|fMeasureByLabel. Must be > 0. The default value is 1.')
eps = Param(parent='undefined', name='eps', doc='log-loss is undefined for p=0 or p=1, so probabilities are clipped to max(eps, min(1 - eps, p)). Must be in range (0, 0.5). The default value is 1e-15.')
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
metricLabel = Param(parent='undefined', name='metricLabel', doc='The class whose metric will be computed in truePositiveRateByLabel|falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel. Must be >= 0. The default value is 0.')
metricName = Param(parent='undefined', name='metricName', doc='metric name in evaluation (f1|accuracy|weightedPrecision|weightedRecall|weightedTruePositiveRate| weightedFalsePositiveRate|weightedFMeasure|truePositiveRateByLabel| falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel| logLoss|hammingLoss)')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
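A brief sketch combining the beta, metricLabel, and metricName params documented above (it assumes an active SparkSession bound to spark; the data is invented for illustration):

from pyspark.ml.evaluation import MulticlassClassificationEvaluator

df = spark.createDataFrame(
    [(0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0)],
    ["prediction", "label"])
evaluator = MulticlassClassificationEvaluator(
    metricName="fMeasureByLabel",  # per-class F-measure
    metricLabel=1.0,               # computed for class 1.0
    beta=2.0)                      # beta > 1 weights recall more than precision
print(evaluator.evaluate(df))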
MultilabelClassificationEvaluator
class pyspark.ml.evaluation.MultilabelClassificationEvaluator(*args, **kwargs)
Evaluator for Multilabel Classification, which expects two input columns: prediction and label.
New in version 3.0.0.
Notes
Experimental
Examples
>>> scoreAndLabels = [([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
...     ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
...     ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]
>>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
>>> evaluator = MultilabelClassificationEvaluator()
>>> evaluator.setPredictionCol("prediction")
MultilabelClassificationEvaluator...
>>> evaluator.evaluate(dataset)
0.63...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
0.54...
>>> mlce_path = temp_path + "/mlce"
>>> evaluator.save(mlce_path)
>>> evaluator2 = MultilabelClassificationEvaluator.load(mlce_path)
>>> str(evaluator2.getPredictionCol())
'prediction'
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset[, params])  Evaluates the output with optional parameters.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getLabelCol()  Gets the value of labelCol or its default value.
getMetricLabel()  Gets the value of metricLabel or its default value.
getMetricName()  Gets the value of metricName or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isLargerBetter()  Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False).
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setLabelCol(value)  Sets the value of labelCol.
setMetricLabel(value)  Sets the value of metricLabel.
setMetricName(value)  Sets the value of metricName.
setParams(self, *[, predictionCol, ...])  Sets params for multilabel classification evaluator.
setPredictionCol(value)  Sets the value of predictionCol.
write()  Returns an MLWriter instance for this ML instance.
Attributes
labelCol
metricLabel
metricName
params  Returns all params ordered by name.
predictionCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
evaluate(dataset, params=None)
Evaluates the output with optional parameters.
New in version 1.4.0.
Parameters
dataset [pyspark.sql.DataFrame] a dataset that contains labels/observations and predictions
params [dict, optional] an optional param map that overrides embedded params
Returns
float metric
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getLabelCol()
Gets the value of labelCol or its default value.

getMetricLabel()
Gets the value of metricLabel or its default value.
New in version 3.0.0.
getMetricName()
Gets the value of metricName or its default value.
New in version 3.0.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isLargerBetter()
Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False). A given evaluator may support multiple metrics which may be maximized or minimized.
New in version 1.5.0.
isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setLabelCol(value)
Sets the value of labelCol.
New in version 3.0.0.
setMetricLabel(value)
Sets the value of metricLabel.
New in version 3.0.0.
setMetricName(value)
Sets the value of metricName.
New in version 3.0.0.
setParams(self, *, predictionCol="prediction", labelCol="label", metricName="f1Measure", metricLabel=0.0)
Sets params for multilabel classification evaluator.
New in version 3.0.0.
setPredictionCol(value)
Sets the value of predictionCol.
New in version 3.0.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')
metricLabel = Param(parent='undefined', name='metricLabel', doc='The class whose metric will be computed in precisionByLabel|recallByLabel|f1MeasureByLabel. Must be >= 0. The default value is 0.')
metricName = Param(parent='undefined', name='metricName', doc='metric name in evaluation (subsetAccuracy|accuracy|hammingLoss|precision|recall|f1Measure|precisionByLabel|recallByLabel|f1MeasureByLabel|microPrecision|microRecall|microF1Measure)')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
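A hedged sketch of one of the metrics listed above (it assumes an active SparkSession bound to spark; prediction and label are arrays of class indices, and the values are illustrative only):

from pyspark.ml.evaluation import MultilabelClassificationEvaluator

df = spark.createDataFrame(
    [([0.0, 1.0], [0.0, 1.0]), ([2.0], [1.0, 2.0]), ([], [0.0])],
    ["prediction", "label"])
evaluator = MultilabelClassificationEvaluator(metricName="microF1Measure")
print(evaluator.evaluate(df))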
ClusteringEvaluator
class pyspark.ml.evaluation.ClusteringEvaluator(*args, **kwargs)
Evaluator for Clustering results, which expects two input columns: prediction and features. The metric computes the Silhouette measure using the squared Euclidean distance.
The Silhouette is a measure for the validation of the consistency within clusters. It ranges between 1 and -1,where a value close to 1 means that the points in a cluster are close to the other points in the same cluster andfar from the points of the other clusters.
New in version 2.3.0.
Examples
>>> from pyspark.ml.linalg import Vectors
>>> featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]),
...     [([0.0, 0.5], 0.0), ([0.5, 0.0], 0.0), ([10.0, 11.0], 1.0),
...     ([10.5, 11.5], 1.0), ([1.0, 1.0], 0.0), ([8.0, 6.0], 1.0)])
>>> dataset = spark.createDataFrame(featureAndPredictions, ["features", "prediction"])
>>> evaluator = ClusteringEvaluator()
>>> evaluator.setPredictionCol("prediction")
ClusteringEvaluator...
>>> evaluator.evaluate(dataset)
0.9079...
>>> featureAndPredictionsWithWeight = map(lambda x: (Vectors.dense(x[0]), x[1], x[2]),
...     [([0.0, 0.5], 0.0, 2.5), ([0.5, 0.0], 0.0, 2.5), ([10.0, 11.0], 1.0, 2.5),
...     ([10.5, 11.5], 1.0, 2.5), ([1.0, 1.0], 0.0, 2.5), ([8.0, 6.0], 1.0, 2.5)])
>>> dataset = spark.createDataFrame(
...     featureAndPredictionsWithWeight, ["features", "prediction", "weight"])
>>> evaluator = ClusteringEvaluator()
>>> evaluator.setPredictionCol("prediction")
ClusteringEvaluator...
>>> evaluator.setWeightCol("weight")
ClusteringEvaluator...
>>> evaluator.evaluate(dataset)
0.9079...
>>> ce_path = temp_path + "/ce"
>>> evaluator.save(ce_path)
>>> evaluator2 = ClusteringEvaluator.load(ce_path)
>>> str(evaluator2.getPredictionCol())
'prediction'
Methods
clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset[, params])  Evaluates the output with optional parameters.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getDistanceMeasure()  Gets the value of distanceMeasure.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getMetricName()  Gets the value of metricName or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isLargerBetter()  Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False).
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setDistanceMeasure(value)  Sets the value of distanceMeasure.
setFeaturesCol(value)  Sets the value of featuresCol.
setMetricName(value)  Sets the value of metricName.
setParams(self, *[, predictionCol, ...])  Sets params for clustering evaluator.
setPredictionCol(value)  Sets the value of predictionCol.
setWeightCol(value)  Sets the value of weightCol.
write()  Returns an MLWriter instance for this ML instance.
Attributes
distanceMeasure
featuresCol
metricName
params  Returns all params ordered by name.
predictionCol
weightCol
Methods Documentation
clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
Parameters
extra [dict, optional] Extra parameters to copy to the new instance
Returns
JavaParams Copy of this instance
evaluate(dataset, params=None)
Evaluates the output with optional parameters.
New in version 1.4.0.
Parameters
dataset [pyspark.sql.DataFrame] a dataset that contains labels/observations and predictions
params [dict, optional] an optional param map that overrides embedded params
Returns
float metric
explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
Parameters
extra [dict, optional] extra param values
Returns
dict merged param map
getDistanceMeasure()
Gets the value of distanceMeasure.
New in version 2.4.0.
getFeaturesCol()
Gets the value of featuresCol or its default value.

getMetricName()
Gets the value of metricName or its default value.
New in version 2.3.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isLargerBetter()
Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False). A given evaluator may support multiple metrics which may be maximized or minimized.
New in version 1.5.0.
isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setDistanceMeasure(value)
Sets the value of distanceMeasure.
New in version 2.4.0.
setFeaturesCol(value)
Sets the value of featuresCol.
setMetricName(value)
Sets the value of metricName.
New in version 2.3.0.
setParams(self, *, predictionCol="prediction", featuresCol="features", metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None)
Sets params for clustering evaluator.
New in version 2.3.0.
setPredictionCol(value)
Sets the value of predictionCol.

setWeightCol(value)
Sets the value of weightCol.
New in version 3.1.0.
write()
Returns an MLWriter instance for this ML instance.
Attributes Documentation
distanceMeasure = Param(parent='undefined', name='distanceMeasure', doc="The distance measure. Supported options: 'squaredEuclidean' and 'cosine'.")
featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')
metricName = Param(parent='undefined', name='metricName', doc='metric name in evaluation (silhouette)')
params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.
predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
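As a final sketch for this class (it assumes an active SparkSession bound to spark; the data is invented for illustration), the distanceMeasure param above switches the Silhouette computation from squared Euclidean to cosine distance:

from pyspark.ml.evaluation import ClusteringEvaluator
from pyspark.ml.linalg import Vectors

df = spark.createDataFrame(
    [(Vectors.dense([0.0, 1.0]), 0.0), (Vectors.dense([0.1, 0.9]), 0.0),
     (Vectors.dense([1.0, 0.0]), 1.0), (Vectors.dense([0.9, 0.1]), 1.0)],
    ["features", "prediction"])
evaluator = ClusteringEvaluator(distanceMeasure="cosine")
print(evaluator.evaluate(df))  # Silhouette computed with cosine distance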
RankingEvaluator
class pyspark.ml.evaluation.RankingEvaluator(*args, **kwargs)
Evaluator for Ranking, which expects two input columns: prediction and label.
New in version 3.0.0.
Notes
Experimental