
pyspark Documentation
Release master

Dec 17, 2020

CONTENTS

1 Getting Started
  1.1 Installation
  1.2 Quickstart

2 User Guide
  2.1 Apache Arrow in PySpark
  2.2 Python Package Management

3 API Reference
  3.1 Spark SQL
  3.2 Structured Streaming
  3.3 ML
  3.4 Spark Streaming
  3.5 MLlib
  3.6 Spark Core
  3.7 Resource Management

4 Development
  4.1 Contributing to PySpark
  4.2 Testing PySpark
  4.3 Debugging PySpark
  4.4 Setting up IDEs

5 Migration Guide
  5.1 Upgrading from PySpark 2.4 to 3.0
  5.2 Upgrading from PySpark 2.3 to 2.4
  5.3 Upgrading from PySpark 2.3.0 to 2.3.1 and above
  5.4 Upgrading from PySpark 2.2 to 2.3
  5.5 Upgrading from PySpark 1.4 to 1.5
  5.6 Upgrading from PySpark 1.0-1.2 to 1.3

Bibliography

Index


Live Notebook | GitHub | Issues | Examples | Community

PySpark is an interface for Apache Spark in Python. It not only allows you to write Spark applications using Python APIs, but also provides the PySpark shell for interactively analyzing your data in a distributed environment. PySpark supports most of Spark's features such as Spark SQL, DataFrame, Streaming, MLlib (Machine Learning) and Spark Core.

Spark SQL and DataFrame

Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrame and can also act as a distributed SQL query engine.

Streaming

Running on top of Spark, the streaming feature in Apache Spark enables powerful interactive and analytical applications across both streaming and historical data, while inheriting Spark's ease of use and fault tolerance characteristics.

MLlib

Built on top of Spark, MLlib is a scalable machine learning library that provides a uniform set of high-level APIs that help users create and tune practical machine learning pipelines.

Spark Core

Spark Core is the underlying general execution engine for the Spark platform that all other functionality is built on top of. It provides an RDD (Resilient Distributed Dataset) and in-memory computing capabilities.


CHAPTER ONE

GETTING STARTED

This page summarizes the basic steps required to set up and get started with PySpark.

1.1 Installation

PySpark is included in the official releases of Spark available at the Apache Spark website. For Python users, PySpark also provides pip installation from PyPI. This is usually for local usage or as a client to connect to a cluster instead of setting up a cluster itself.

This page includes instructions for installing PySpark by using pip, Conda, downloading manually, and building from the source.

1.1.1 Python Version Supported

Python 3.6 and above.

1.1.2 Using PyPI

PySpark installation using PyPI is as follows:

pip install pyspark

If you want to install extra dependencies for a specific component, you can install it as below:

pip install pyspark[sql]

For PySpark with or without a specific Hadoop version, you can install it by using the HADOOP_VERSION environment variable as below:

HADOOP_VERSION=2.7 pip install pyspark

The default distribution uses Hadoop 3.2 and Hive 2.3. If users specify a different version of Hadoop, the pip installation automatically downloads that version and uses it in PySpark. Downloading it can take a while depending on the network and the mirror chosen. PYSPARK_RELEASE_MIRROR can be set to manually choose the mirror for faster downloading.

PYSPARK_RELEASE_MIRROR=http://mirror.apache-kr.org HADOOP_VERSION=2.7 pip install pyspark

It is recommended to use the -v option in pip to track the installation and download status.


HADOOP_VERSION=2.7 pip install pyspark -v

Supported values in HADOOP_VERSION are:

• without: Spark pre-built with user-provided Apache Hadoop

• 2.7: Spark pre-built for Apache Hadoop 2.7

• 3.2: Spark pre-built for Apache Hadoop 3.2 and later (default)

Note that this way of installing PySpark with or without a specific Hadoop version is experimental. It can change or be removed between minor releases.

1.1.3 Using Conda

Conda is an open-source package management and environment management system which is a part of the Anaconda distribution. It is both cross-platform and language agnostic. In practice, Conda can replace both pip and virtualenv.

Create a new virtual environment from your terminal as shown below:

conda create -n pyspark_env

After the virtual environment is created, it should be visible under the list of Conda environments, which can be seen using the following command:

conda env list

Now activate the newly created environment with the following command:

conda activate pyspark_env

You can install PySpark by Using PyPI in the newly created environment, for example as below. It will install PySpark under the new virtual environment pyspark_env created above.

pip install pyspark

Alternatively, you can install PySpark from Conda itself as below:

conda install pyspark

However, note that PySpark on Conda is not necessarily synced with the PySpark release cycle because it is maintained separately by the community.

1.1.4 Manually Downloading

PySpark is included in the distributions available at the Apache Spark website. You can download the distribution you want from the site. After that, uncompress the tar file into the directory where you want to install Spark, for example, as below:

tar xzvf spark-3.0.0-bin-hadoop2.7.tgz

Ensure the SPARK_HOME environment variable points to the directory where the tar file has been extracted. Update the PYTHONPATH environment variable so that it can find PySpark and Py4J under SPARK_HOME/python/lib. One example of doing this is shown below:


cd spark-3.0.0-bin-hadoop2.7
export SPARK_HOME=`pwd`
export PYTHONPATH=$(ZIPS=("$SPARK_HOME"/python/lib/*.zip); IFS=:; echo "${ZIPS[*]}"):$PYTHONPATH

1.1.5 Installing from Source

To install PySpark from source, refer to Building Spark.

1.1.6 Dependencies

Package   Minimum supported version   Note
pandas    0.23.2                      Optional for SQL
NumPy     1.7                         Required for ML
pyarrow   1.0.0                       Optional for SQL
Py4J      0.10.9                      Required

Note that PySpark requires Java 8 or later with JAVA_HOME properly set. If using JDK 11, set -Dio.netty.tryReflectionSetAccessible=true for Arrow-related features and refer to Downloading.
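For example, a minimal sketch (not from the original docs; the application name is illustrative) of passing that JVM flag to both the driver and the executors when building the session yourself:

from pyspark.sql import SparkSession

# On JDK 11, pass the Netty flag to both JVMs so Arrow-based features work.
spark = SparkSession.builder \
    .appName("arrow-on-jdk11") \
    .config("spark.driver.extraJavaOptions", "-Dio.netty.tryReflectionSetAccessible=true") \
    .config("spark.executor.extraJavaOptions", "-Dio.netty.tryReflectionSetAccessible=true") \
    .getOrCreate()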

1.2 Quickstart

This is a short introduction and quickstart for the PySpark DataFrame API. PySpark DataFrames are lazily evaluated. They are implemented on top of RDDs. When Spark transforms data, it does not immediately compute the transformation but plans how to compute it later. When actions such as collect() are explicitly called, the computation starts. This notebook shows the basic usages of the DataFrame, geared mainly for new users. You can run the latest version of these examples by yourself on a live notebook here.

There is also other useful information in the Apache Spark documentation site; see the latest version of Spark SQL and DataFrames, RDD Programming Guide, Structured Streaming Programming Guide, Spark Streaming Programming Guide and Machine Learning Library (MLlib) Guide.

PySpark applications start with initializing SparkSession, which is the entry point of PySpark, as below. When running in the PySpark shell via the pyspark executable, the shell automatically creates the session in the variable spark for users.

[1]: from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()


1.2.1 DataFrame Creation

A PySpark DataFrame can be created via pyspark.sql.SparkSession.createDataFrame, typically by passing a list of lists, tuples, dictionaries and pyspark.sql.Rows, a pandas DataFrame and an RDD consisting of such a list. pyspark.sql.SparkSession.createDataFrame takes the schema argument to specify the schema of the DataFrame. When it is omitted, PySpark infers the corresponding schema by taking a sample from the data.

Firstly, you can create a PySpark DataFrame from a list of rows.

[2]: from datetime import datetime, date
     import pandas as pd
     from pyspark.sql import Row

     df = spark.createDataFrame([
         Row(a=1, b=2., c='string1', d=date(2000, 1, 1), e=datetime(2000, 1, 1, 12, 0)),
         Row(a=2, b=3., c='string2', d=date(2000, 2, 1), e=datetime(2000, 1, 2, 12, 0)),
         Row(a=4, b=5., c='string3', d=date(2000, 3, 1), e=datetime(2000, 1, 3, 12, 0))
     ])
     df

[2]: DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]

Create a PySpark DataFrame with an explicit schema.

[3]: df = spark.createDataFrame([
         (1, 2., 'string1', date(2000, 1, 1), datetime(2000, 1, 1, 12, 0)),
         (2, 3., 'string2', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
         (3, 4., 'string3', date(2000, 3, 1), datetime(2000, 1, 3, 12, 0))
     ], schema='a long, b double, c string, d date, e timestamp')
     df

[3]: DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]

Create a PySpark DataFrame from a pandas DataFrame

[4]: pandas_df = pd.DataFrame({
         'a': [1, 2, 3],
         'b': [2., 3., 4.],
         'c': ['string1', 'string2', 'string3'],
         'd': [date(2000, 1, 1), date(2000, 2, 1), date(2000, 3, 1)],
         'e': [datetime(2000, 1, 1, 12, 0), datetime(2000, 1, 2, 12, 0), datetime(2000, 1, 3, 12, 0)]
     })
     df = spark.createDataFrame(pandas_df)
     df

[4]: DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]

Create a PySpark DataFrame from an RDD consisting of a list of tuples.

[5]: rdd = spark.sparkContext.parallelize([
         (1, 2., 'string1', date(2000, 1, 1), datetime(2000, 1, 1, 12, 0)),
         (2, 3., 'string2', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
         (3, 4., 'string3', date(2000, 3, 1), datetime(2000, 1, 3, 12, 0))
     ])
     df = spark.createDataFrame(rdd, schema=['a', 'b', 'c', 'd', 'e'])
     df


[5]: DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]

The DataFrames created above all have the same results and schema.

[6]: # All DataFrames above result same.
     df.show()
     df.printSchema()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|
+---+---+-------+----------+-------------------+

root
 |-- a: long (nullable = true)
 |-- b: double (nullable = true)
 |-- c: string (nullable = true)
 |-- d: date (nullable = true)
 |-- e: timestamp (nullable = true)

1.2.2 Viewing Data

The top rows of a DataFrame can be displayed using DataFrame.show().

[7]: df.show(1)

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
+---+---+-------+----------+-------------------+
only showing top 1 row

Alternatively, you can enable the spark.sql.repl.eagerEval.enabled configuration for eager evaluation of PySpark DataFrames in notebooks such as Jupyter. The number of rows to show can be controlled via the spark.sql.repl.eagerEval.maxNumRows configuration.

[8]: spark.conf.set('spark.sql.repl.eagerEval.enabled', True)
     df

[8]: DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]
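The row cap for this eager notebook representation can be adjusted the same way; a minimal sketch (assuming the session created above):

spark.conf.set('spark.sql.repl.eagerEval.maxNumRows', 5)  # render at most 5 rows in the notebook repr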

The rows can also be shown vertically. This is useful when rows are too long to show horizontally.

[9]: df.show(1, vertical=True)

-RECORD 0------------------
 a   | 1
 b   | 2.0
 c   | string1
 d   | 2000-01-01
 e   | 2000-01-01 12:00:00
only showing top 1 row

You can see the DataFrame’s schema and column names as follows:

[10]: df.columns

[10]: ['a', 'b', 'c', 'd', 'e']

[11]: df.printSchema()

root
 |-- a: long (nullable = true)
 |-- b: double (nullable = true)
 |-- c: string (nullable = true)
 |-- d: date (nullable = true)
 |-- e: timestamp (nullable = true)

Show the summary of the DataFrame

[12]: df.select("a", "b", "c").describe().show()

+-------+---+---+-------+
|summary|  a|  b|      c|
+-------+---+---+-------+
|  count|  3|  3|      3|
|   mean|2.0|3.0|   null|
| stddev|1.0|1.0|   null|
|    min|  1|2.0|string1|
|    max|  3|4.0|string3|
+-------+---+---+-------+

DataFrame.collect() collects the distributed data to the driver side as the local data in Python. Note that this can throw an out-of-memory error when the dataset is too large to fit in the driver side because it collects all the data from executors to the driver side.

[13]: df.collect()

[13]: [Row(a=1, b=2.0, c='string1', d=datetime.date(2000, 1, 1), e=datetime.datetime(2000, 1, 1, 12, 0)),
      Row(a=2, b=3.0, c='string2', d=datetime.date(2000, 2, 1), e=datetime.datetime(2000, 1, 2, 12, 0)),
      Row(a=3, b=4.0, c='string3', d=datetime.date(2000, 3, 1), e=datetime.datetime(2000, 1, 3, 12, 0))]

In order to avoid throwing an out-of-memory exception, use DataFrame.take() or DataFrame.tail().

[14]: df.take(1)

[14]: [Row(a=1, b=2.0, c='string1', d=datetime.date(2000, 1, 1), e=datetime.datetime(2000, 1, 1, 12, 0))]
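DataFrame.tail() works analogously from the other end of the DataFrame; a minimal sketch using the same df (the commented result is what you would expect for this data):

df.tail(1)
# [Row(a=3, b=4.0, c='string3', d=datetime.date(2000, 3, 1), e=datetime.datetime(2000, 1, 3, 12, 0))]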

PySpark DataFrame also provides the conversion back to a pandas DataFrame to leverage pandas APIs. Note that toPandas also collects all data into the driver side, which can easily cause an out-of-memory error when the data is too large to fit into the driver side.


[15]: df.toPandas()

[15]:    a    b        c           d                   e
      0  1  2.0  string1  2000-01-01 2000-01-01 12:00:00
      1  2  3.0  string2  2000-02-01 2000-01-02 12:00:00
      2  3  4.0  string3  2000-03-01 2000-01-03 12:00:00

1.2.3 Selecting and Accessing Data

PySpark DataFrame is lazily evaluated and simply selecting a column does not trigger the computation but it returns a Column instance.

[16]: df.a

[16]: Column<b'a'>

In fact, most column-wise operations return Columns.

[17]: from pyspark.sql import Column
      from pyspark.sql.functions import upper

type(df.c) == type(upper(df.c)) == type(df.c.isNull())

[17]: True

These Columns can be used to select the columns from a DataFrame. For example, DataFrame.select() takes the Column instances and returns another DataFrame.

[18]: df.select(df.c).show()

+-------+
|      c|
+-------+
|string1|
|string2|
|string3|
+-------+

Assign a new Column instance.

[19]: df.withColumn('upper_c', upper(df.c)).show()

+---+---+-------+----------+-------------------+-------+
|  a|  b|      c|         d|                  e|upper_c|
+---+---+-------+----------+-------------------+-------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|STRING1|
|  2|3.0|string2|2000-02-01|2000-01-02 12:00:00|STRING2|
|  3|4.0|string3|2000-03-01|2000-01-03 12:00:00|STRING3|
+---+---+-------+----------+-------------------+-------+

To select a subset of rows, use DataFrame.filter().

[20]: df.filter(df.a == 1).show()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
+---+---+-------+----------+-------------------+

1.2.4 Applying a Function

PySpark supports various UDFs and APIs to allow users to execute Python native functions. See also the latest Pandas UDFs and Pandas Function APIs. For instance, the example below allows users to directly use the APIs in a pandas Series within a Python native function.

[21]: import pandas
      from pyspark.sql.functions import pandas_udf

      @pandas_udf('long')
      def pandas_plus_one(series: pd.Series) -> pd.Series:
          # Simply plus one by using pandas Series.
          return series + 1

      df.select(pandas_plus_one(df.a)).show()

+------------------+
|pandas_plus_one(a)|
+------------------+
|                 2|
|                 3|
|                 4|
+------------------+

Another example is DataFrame.mapInPandas which allows users to directly use the APIs in a pandas DataFrame without any restrictions such as the result length.

[22]: def pandas_filter_func(iterator):
          for pandas_df in iterator:
              yield pandas_df[pandas_df.a == 1]

      df.mapInPandas(pandas_filter_func, schema=df.schema).show()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2000-01-01|2000-01-01 12:00:00|
+---+---+-------+----------+-------------------+


1.2.5 Grouping Data

PySpark DataFrame also provides a way of handling grouped data by using the common approach, the split-apply-combine strategy. It groups the data by a certain condition, applies a function to each group and then combines them back to the DataFrame.

[23]: df = spark.createDataFrame([
          ['red', 'banana', 1, 10], ['blue', 'banana', 2, 20], ['red', 'carrot', 3, 30],
          ['blue', 'grape', 4, 40], ['red', 'carrot', 5, 50], ['black', 'carrot', 6, 60],
          ['red', 'banana', 7, 70], ['red', 'grape', 8, 80]], schema=['color', 'fruit', 'v1', 'v2'])
      df.show()

+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana|  1| 10|
| blue|banana|  2| 20|
|  red|carrot|  3| 30|
| blue| grape|  4| 40|
|  red|carrot|  5| 50|
|black|carrot|  6| 60|
|  red|banana|  7| 70|
|  red| grape|  8| 80|
+-----+------+---+---+

Grouping and then applying the avg() function to the resulting groups.

[24]: df.groupby('color').avg().show()

+-----+-------+-------+
|color|avg(v1)|avg(v2)|
+-----+-------+-------+
|  red|    4.8|   48.0|
|black|    6.0|   60.0|
| blue|    3.0|   30.0|
+-----+-------+-------+

You can also apply a Python native function against each group by using pandas APIs.

[25]: def plus_mean(pandas_df):
          return pandas_df.assign(v1=pandas_df.v1 - pandas_df.v1.mean())

      df.groupby('color').applyInPandas(plus_mean, schema=df.schema).show()

+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana| -3| 10|
|  red|carrot| -1| 30|
|  red|carrot|  0| 50|
|  red|banana|  2| 70|
|  red| grape|  3| 80|
|black|carrot|  0| 60|
| blue|banana| -1| 20|
| blue| grape|  1| 40|
+-----+------+---+---+

Co-grouping and applying a function.

[26]: df1 = spark.createDataFrame(
          [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
          ('time', 'id', 'v1'))

      df2 = spark.createDataFrame(
          [(20000101, 1, 'x'), (20000101, 2, 'y')],
          ('time', 'id', 'v2'))

      def asof_join(l, r):
          return pd.merge_asof(l, r, on='time', by='id')

      df1.groupby('id').cogroup(df2.groupby('id')).applyInPandas(
          asof_join, schema='time int, id int, v1 double, v2 string').show()

+--------+---+---+---+
|    time| id| v1| v2|
+--------+---+---+---+
|20000101|  1|1.0|  x|
|20000102|  1|3.0|  x|
|20000101|  2|2.0|  y|
|20000102|  2|4.0|  y|
+--------+---+---+---+

1.2.6 Getting Data in/out

CSV is straightforward and easy to use. Parquet and ORC are efficient and compact file formats that are faster to read and write.

There are many other data sources available in PySpark such as JDBC, text, binaryFile, Avro, etc. See also the latest Spark SQL, DataFrames and Datasets Guide in the Apache Spark documentation.
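As an illustration, a minimal sketch (not from the original docs; 'data.json' and 'out' are hypothetical paths) of the generic format/option reader and writer API that backs these sources:

# Read a JSON source through the generic DataFrameReader API.
json_df = spark.read.format("json").option("multiLine", "false").load("data.json")

# Write it back out with an explicit save mode.
json_df.write.format("json").mode("overwrite").save("out")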

CSV

[27]: df.write.csv('foo.csv', header=True)
      spark.read.csv('foo.csv', header=True).show()

+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana|  1| 10|
| blue|banana|  2| 20|
|  red|carrot|  3| 30|
| blue| grape|  4| 40|
|  red|carrot|  5| 50|
|black|carrot|  6| 60|
|  red|banana|  7| 70|
|  red| grape|  8| 80|
+-----+------+---+---+


Parquet

[28]: df.write.parquet('bar.parquet')
      spark.read.parquet('bar.parquet').show()

+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana|  1| 10|
| blue|banana|  2| 20|
|  red|carrot|  3| 30|
| blue| grape|  4| 40|
|  red|carrot|  5| 50|
|black|carrot|  6| 60|
|  red|banana|  7| 70|
|  red| grape|  8| 80|
+-----+------+---+---+

ORC

[29]: df.write.orc('zoo.orc')
      spark.read.orc('zoo.orc').show()

+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana|  1| 10|
| blue|banana|  2| 20|
|  red|carrot|  3| 30|
| blue| grape|  4| 40|
|  red|carrot|  5| 50|
|black|carrot|  6| 60|
|  red|banana|  7| 70|
|  red| grape|  8| 80|
+-----+------+---+---+

1.2.7 Working with SQL

DataFrame and Spark SQL share the same execution engine so they can be used interchangeably and seamlessly. For example, you can register the DataFrame as a table and run SQL queries easily as below:

[30]: df.createOrReplaceTempView("tableA")
      spark.sql("SELECT count(*) from tableA").show()

+--------+
|count(1)|
+--------+
|       8|
+--------+

In addition, UDFs can be registered and invoked in SQL out of the box:


[31]: @pandas_udf("integer")def add_one(s: pd.Series) -> pd.Series:

return s + 1

spark.udf.register("add_one", add_one)spark.sql("SELECT add_one(v1) FROM tableA").show()

+-----------+
|add_one(v1)|
+-----------+
|          2|
|          3|
|          4|
|          5|
|          6|
|          7|
|          8|
|          9|
+-----------+

These SQL expressions can directly be mixed and used as PySpark columns.

[32]: from pyspark.sql.functions import expr

      df.selectExpr('add_one(v1)').show()
      df.select(expr('count(*)') > 0).show()

+-----------+
|add_one(v1)|
+-----------+
|          2|
|          3|
|          4|
|          5|
|          6|
|          7|
|          8|
|          9|
+-----------+

+--------------+
|(count(1) > 0)|
+--------------+
|          true|
+--------------+


CHAPTER TWO

USER GUIDE

2.1 Apache Arrow in PySpark

Apache Arrow is an in-memory columnar data format that is used in Spark to efficiently transfer data between JVM and Python processes. This currently is most beneficial to Python users that work with Pandas/NumPy data. Its usage is not automatic and might require some minor changes to configuration or code to take full advantage and ensure compatibility. This guide will give a high-level description of how to use Arrow in Spark and highlight any differences when working with Arrow-enabled data.

2.1.1 Ensure PyArrow Installed

To use Apache Arrow in PySpark, the recommended version of PyArrow should be installed. If you install PySpark using pip, then PyArrow can be brought in as an extra dependency of the SQL module with the command pip install pyspark[sql]. Otherwise, you must ensure that PyArrow is installed and available on all cluster nodes. You can install it using pip or conda from the conda-forge channel. See PyArrow installation for details.

2.1.2 Enabling for Conversion to/from Pandas

Arrow is available as an optimization when converting a Spark DataFrame to a Pandas DataFrame using the call DataFrame.toPandas() and when creating a Spark DataFrame from a Pandas DataFrame with SparkSession.createDataFrame(). To use Arrow when executing these calls, users need to first set the Spark configuration spark.sql.execution.arrow.pyspark.enabled to true. This is disabled by default.

In addition, optimizations enabled by spark.sql.execution.arrow.pyspark.enabled could fall back automatically to a non-Arrow implementation if an error occurs before the actual computation within Spark. This can be controlled by spark.sql.execution.arrow.pyspark.fallback.enabled.
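For example, a minimal sketch (assuming an active SparkSession named spark) of toggling that fallback explicitly; the full Arrow conversion example from the documentation follows below:

# Fail loudly instead of silently falling back to the non-Arrow path.
spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")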

import numpy as np  # type: ignore[import]
import pandas as pd  # type: ignore[import]

# Enable Arrow-based columnar data transfers
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Generate a Pandas DataFrame
pdf = pd.DataFrame(np.random.rand(100, 3))

# Create a Spark DataFrame from a Pandas DataFrame using Arrow
df = spark.createDataFrame(pdf)

# Convert the Spark DataFrame back to a Pandas DataFrame using Arrow
result_pdf = df.select("*").toPandas()

Using the above optimizations with Arrow will produce the same results as when Arrow is not enabled.

Note that even with Arrow, DataFrame.toPandas() results in the collection of all records in the DataFrame to the driver program and should be done on a small subset of the data. Not all Spark data types are currently supported and an error can be raised if a column has an unsupported type. If an error occurs during SparkSession.createDataFrame(), Spark will fall back to create the DataFrame without Arrow.

2.1.3 Pandas UDFs (a.k.a. Vectorized UDFs)

Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and Pandas to work with the data, which allows vectorized operations. A Pandas UDF is defined using pandas_udf() as a decorator or to wrap the function, and no additional configuration is required. A Pandas UDF behaves as a regular PySpark function API in general.

Before Spark 3.0, Pandas UDFs used to be defined with pyspark.sql.functions.PandasUDFType. From Spark 3.0 with Python 3.6+, you can also use Python type hints. Using Python type hints is preferred and pyspark.sql.functions.PandasUDFType will be deprecated in a future release.
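As an illustration, a minimal sketch (not from the original docs) of the same scalar UDF written in the old PandasUDFType style and in the preferred type-hint style:

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Old style (to be deprecated): the UDF kind is passed explicitly.
@pandas_udf("long", PandasUDFType.SCALAR)
def plus_one_old(v):
    return v + 1

# Preferred style (Spark 3.0+, Python 3.6+): the kind is inferred from the type hints.
@pandas_udf("long")
def plus_one_new(v: pd.Series) -> pd.Series:
    return v + 1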

Note that the type hint should use pandas.Series in all cases, but there is one variant where pandas.DataFrame should be used for its input or output type hint instead: when the input or output column is of StructType. The following example shows a Pandas UDF which takes a long column, a string column and a struct column, and outputs a struct column. It requires the function to specify the type hints of pandas.Series and pandas.DataFrame as below:

import pandas as pd

from pyspark.sql.functions import pandas_udf

@pandas_udf("col1 string, col2 long")
def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame:
    s3['col2'] = s1 + s2.str.len()
    return s3

# Create a Spark DataFrame that has three columns including a struct column.
df = spark.createDataFrame(
    [[1, "a string", ("a nested string",)]],
    "long_col long, string_col string, struct_col struct<col1:string>")

df.printSchema()
# root
# |-- long_col: long (nullable = true)
# |-- string_col: string (nullable = true)
# |-- struct_col: struct (nullable = true)
# |    |-- col1: string (nullable = true)

df.select(func("long_col", "string_col", "struct_col")).printSchema()
# |-- func(long_col, string_col, struct_col): struct (nullable = true)
# |    |-- col1: string (nullable = true)
# |    |-- col2: long (nullable = true)

The following sections describe the combinations of the supported type hints. For simplicity, the pandas.DataFrame variant is omitted.


Series to Series

The type hint can be expressed as pandas.Series, ... -> pandas.Series.

By using pandas_udf() with the function having such type hints above, it creates a Pandas UDF where the given function takes one or more pandas.Series and outputs one pandas.Series. The output of the function should always be of the same length as the input. Internally, PySpark will execute a Pandas UDF by splitting columns into batches and calling the function for each batch as a subset of the data, then concatenating the results together.

The following example shows how to create this Pandas UDF that computes the product of 2 columns.

import pandas as pd

from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

# Declare the function and create the UDF
def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series:
    return a * b

multiply = pandas_udf(multiply_func, returnType=LongType())

# The function for a pandas_udf should be able to execute with local Pandas data
x = pd.Series([1, 2, 3])
print(multiply_func(x, x))
# 0    1
# 1    4
# 2    9
# dtype: int64

# Create a Spark DataFrame, 'spark' is an existing SparkSession
df = spark.createDataFrame(pd.DataFrame(x, columns=["x"]))

# Execute function as a Spark vectorized UDF
df.select(multiply(col("x"), col("x"))).show()
# +-------------------+
# |multiply_func(x, x)|
# +-------------------+
# |                  1|
# |                  4|
# |                  9|
# +-------------------+

For detailed usage, please see pandas_udf().

Iterator of Series to Iterator of Series

The type hint can be expressed as Iterator[pandas.Series] -> Iterator[pandas.Series].

By using pandas_udf() with the function having such type hints above, it creates a Pandas UDF where the given function takes an iterator of pandas.Series and outputs an iterator of pandas.Series. The length of the entire output from the function should be the same as the length of the entire input; therefore, it can prefetch the data from the input iterator as long as the lengths are the same. In this case, the created Pandas UDF requires one input column when the Pandas UDF is called. To use multiple input columns, a different type hint is required. See Iterator of Multiple Series to Iterator of Series.

It is also useful when the UDF execution requires initializing some state, although internally it works identically to the Series to Series case. The pseudocode below illustrates the example.


@pandas_udf("long")def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:

# Do some expensive initialization with a statestate = very_expensive_initialization()for x in iterator:

# Use that state for whole iterator.yield calculate_with_state(x, state)

df.select(calculate("value")).show()

The following example shows how to create this Pandas UDF:

from typing import Iterator

import pandas as pd

from pyspark.sql.functions import pandas_udf

pdf = pd.DataFrame([1, 2, 3], columns=["x"])
df = spark.createDataFrame(pdf)

# Declare the function and create the UDF
@pandas_udf("long")
def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    for x in iterator:
        yield x + 1

df.select(plus_one("x")).show()
# +-----------+
# |plus_one(x)|
# +-----------+
# |          2|
# |          3|
# |          4|
# +-----------+

For detailed usage, please see pandas_udf().

Iterator of Multiple Series to Iterator of Series

The type hint can be expressed as Iterator[Tuple[pandas.Series, ...]] -> Iterator[pandas.Series].

By using pandas_udf() with the function having such type hints above, it creates a Pandas UDF where the given function takes an iterator of a tuple of multiple pandas.Series and outputs an iterator of pandas.Series. In this case, the created Pandas UDF requires as many input columns as the series in the tuple when the Pandas UDF is called. Otherwise, it has the same characteristics and restrictions as the Iterator of Series to Iterator of Series case.

The following example shows how to create this Pandas UDF:

from typing import Iterator, Tuple

import pandas as pd

from pyspark.sql.functions import pandas_udf

pdf = pd.DataFrame([1, 2, 3], columns=["x"])
df = spark.createDataFrame(pdf)

# Declare the function and create the UDF
@pandas_udf("long")
def multiply_two_cols(
        iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    for a, b in iterator:
        yield a * b

df.select(multiply_two_cols("x", "x")).show()
# +-----------------------+
# |multiply_two_cols(x, x)|
# +-----------------------+
# |                      1|
# |                      4|
# |                      9|
# +-----------------------+

For detailed usage, please see pandas_udf().

Series to Scalar

The type hint can be expressed as pandas.Series, ... -> Any.

By using pandas_udf() with the function having such type hints above, it creates a Pandas UDF similar to PySpark's aggregate functions. The given function takes pandas.Series and returns a scalar value. The return type should be a primitive data type, and the returned scalar can be either a Python primitive type, e.g., int or float, or a NumPy data type, e.g., numpy.int64 or numpy.float64. Any should ideally be a specific scalar type accordingly.

This UDF can be also used with GroupedData.agg() and Window. It defines an aggregation from one or more pandas.Series to a scalar value, where each pandas.Series represents a column within the group or window.

Note that this type of UDF does not support partial aggregation and all data for a group or window will be loaded into memory. Also, only unbounded windows are supported with grouped aggregate Pandas UDFs currently. The following example shows how to use this type of UDF to compute the mean with group-by and window operations:

import pandas as pd

from pyspark.sql.functions import pandas_udf
from pyspark.sql import Window

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))

# Declare the function and create the UDF
@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    return v.mean()

df.select(mean_udf(df['v'])).show()
# +-----------+
# |mean_udf(v)|
# +-----------+
# |        4.2|
# +-----------+

df.groupby("id").agg(mean_udf(df['v'])).show()
# +---+-----------+
# | id|mean_udf(v)|
# +---+-----------+
# |  1|        1.5|
# |  2|        6.0|
# +---+-----------+

w = Window \
    .partitionBy('id') \
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()
# +---+----+------+
# | id|   v|mean_v|
# +---+----+------+
# |  1| 1.0|   1.5|
# |  1| 2.0|   1.5|
# |  2| 3.0|   6.0|
# |  2| 5.0|   6.0|
# |  2|10.0|   6.0|
# +---+----+------+

For detailed usage, please see pandas_udf().

2.1.4 Pandas Function APIs

Pandas Function APIs can directly apply a Python native function against the whole DataFrame by using Pandas instances. Internally it works similarly with Pandas UDFs by using Arrow to transfer data and Pandas to work with the data, which allows vectorized operations. However, a Pandas Function API behaves as a regular API under PySpark DataFrame instead of Column, and Python type hints in Pandas Function APIs are optional and do not affect how it works internally at this moment, although they might be required in the future.

From Spark 3.0, grouped map pandas UDF is now categorized as a separate Pandas Function API, DataFrame.groupby().applyInPandas(). It is still possible to use it with pyspark.sql.functions.PandasUDFType and DataFrame.groupby().apply() as it was; however, it is preferred to use DataFrame.groupby().applyInPandas() directly. Using pyspark.sql.functions.PandasUDFType will be deprecated in the future.

Grouped Map

Grouped map operations with Pandas instances are supported by DataFrame.groupby().applyInPandas() which requires a Python function that takes a pandas.DataFrame and returns another pandas.DataFrame. It maps each group to each pandas.DataFrame in the Python function.

This API implements the “split-apply-combine” pattern which consists of three steps:

• Split the data into groups by using DataFrame.groupBy().

• Apply a function on each group. The input and output of the function are both pandas.DataFrame. The input data contains all the rows and columns for each group.

• Combine the results into a new PySpark DataFrame.

To use DataFrame.groupBy().applyInPandas(), the user needs to define the following:


• A Python function that defines the computation for each group.

• A StructType object or a string that defines the schema of the output PySpark DataFrame.

The column labels of the returned pandas.DataFrame must either match the field names in the defined output schema if specified as strings, or match the field data types by position if not strings, e.g. integer indices. See pandas.DataFrame on how to label columns when constructing a pandas.DataFrame.

Note that all data for a group will be loaded into memory before the function is applied. This can lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for maxRecordsPerBatch is not applied on groups and it is up to the user to ensure that the grouped data will fit into the available memory.

The following example shows how to use DataFrame.groupby().applyInPandas() to subtract the mean from each value in the group.

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))

def subtract_mean(pdf):
    # pdf is a pandas.DataFrame
    v = pdf.v
    return pdf.assign(v=v - v.mean())

df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double").show()
# +---+----+
# | id|   v|
# +---+----+
# |  1|-0.5|
# |  1| 0.5|
# |  2|-3.0|
# |  2|-1.0|
# |  2| 4.0|
# +---+----+

For detailed usage, please see GroupedData.applyInPandas().

Map

Map operations with Pandas instances are supported by DataFrame.mapInPandas() which maps an iterator of pandas.DataFrames to another iterator of pandas.DataFrames that represents the current PySpark DataFrame and returns the result as a PySpark DataFrame. The function takes and outputs an iterator of pandas.DataFrame. It can return the output of arbitrary length in contrast to some Pandas UDFs, although internally it works similarly with the Series to Series Pandas UDF.

The following example shows how to use DataFrame.mapInPandas():

df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

def filter_func(iterator):
    for pdf in iterator:
        yield pdf[pdf.id == 1]

df.mapInPandas(filter_func, schema=df.schema).show()
# +---+---+
# | id|age|
# +---+---+
# |  1| 21|
# +---+---+

For detailed usage, please see DataFrame.mapInPandas().

Co-grouped Map

Co-grouped map operations with Pandas instances are supported by DataFrame.groupby().cogroup().applyInPandas() which allows two PySpark DataFrames to be cogrouped by a common key and then a Python function applied to each cogroup. It consists of the following steps:

• Shuffle the data such that the groups of each dataframe which share a key are cogrouped together.

• Apply a function to each cogroup. The input of the function is two pandas.DataFrame (with an optional tuple representing the key). The output of the function is a pandas.DataFrame.

• Combine the pandas.DataFrames from all groups into a new PySpark DataFrame.

To use groupBy().cogroup().applyInPandas(), the user needs to define the following:

• A Python function that defines the computation for each cogroup.

• A StructType object or a string that defines the schema of the output PySpark DataFrame.

The column labels of the returned pandas.DataFrame must either match the field names in the defined output schema if specified as strings, or match the field data types by position if not strings, e.g. integer indices. See pandas.DataFrame on how to label columns when constructing a pandas.DataFrame.

Note that all data for a cogroup will be loaded into memory before the function is applied. This can lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for maxRecordsPerBatch is not applied and it is up to the user to ensure that the cogrouped data will fit into the available memory.

The following example shows how to use DataFrame.groupby().cogroup().applyInPandas() to perform an asof join between two datasets.

import pandas as pd

df1 = spark.createDataFrame(
    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
    ("time", "id", "v1"))

df2 = spark.createDataFrame(
    [(20000101, 1, "x"), (20000101, 2, "y")],
    ("time", "id", "v2"))

def asof_join(l, r):
    return pd.merge_asof(l, r, on="time", by="id")

df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
    asof_join, schema="time int, id int, v1 double, v2 string").show()
# +--------+---+---+---+
# |    time| id| v1| v2|
# +--------+---+---+---+
# |20000101|  1|1.0|  x|
# |20000102|  1|3.0|  x|
# |20000101|  2|2.0|  y|
# |20000102|  2|4.0|  y|
# +--------+---+---+---+


For detailed usage, please see PandasCogroupedOps.applyInPandas().

2.1.5 Usage Notes

Supported SQL Types

Currently, all Spark SQL data types are supported by Arrow-based conversion except ArrayType of TimestampType and nested StructType. MapType is only supported when using PyArrow 2.0.0 and above.

Setting Arrow Batch Size

Data partitions in Spark are converted into Arrow record batches, which can temporarily lead to high memory usage in the JVM. To avoid possible out of memory exceptions, the size of the Arrow record batches can be adjusted by setting the conf spark.sql.execution.arrow.maxRecordsPerBatch to an integer that will determine the maximum number of rows for each batch. The default value is 10,000 records per batch. If the number of columns is large, the value should be adjusted accordingly. Using this limit, each data partition will be made into 1 or more record batches for processing.
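For instance, a minimal sketch (assuming an active SparkSession named spark) of lowering the batch size for wide rows:

# Halve the default batch size when each row carries many columns.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "5000")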

Timestamp with Time Zone Semantics

Spark internally stores timestamps as UTC values, and timestamp data that is brought in without a specified time zone is converted as local time to UTC with microsecond resolution. When timestamp data is exported or displayed in Spark, the session time zone is used to localize the timestamp values. The session time zone is set with the configuration spark.sql.session.timeZone and will default to the JVM system local time zone if not set. Pandas uses a datetime64 type with nanosecond resolution, datetime64[ns], with optional time zone on a per-column basis.
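A minimal sketch (assuming an active SparkSession named spark) of pinning the session time zone that these conversions use:

# Make Spark-to-Pandas timestamp localization deterministic across environments.
spark.conf.set("spark.sql.session.timeZone", "UTC")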

When timestamp data is transferred from Spark to Pandas it will be converted to nanoseconds and each column will be converted to the Spark session time zone then localized to that time zone, which removes the time zone and displays values as local time. This will occur when calling DataFrame.toPandas() or pandas_udf with timestamp columns.

When timestamp data is transferred from Pandas to Spark, it will be converted to UTC microseconds. This occurs when calling SparkSession.createDataFrame() with a Pandas DataFrame or when returning a timestamp from a pandas_udf. These conversions are done automatically to ensure Spark will have data in the expected format, so it is not necessary to do any of these conversions yourself. Any nanosecond values will be truncated.

Note that a standard UDF (non-Pandas) will load timestamp data as Python datetime objects, which is different from a Pandas timestamp. It is recommended to use Pandas time series functionality when working with timestamps in pandas_udfs to get the best performance; see here for details.

Recommended Pandas and PyArrow Versions

For usage with pyspark.sql, the minimum supported version of Pandas is 0.23.2 and the minimum supported version of PyArrow is 1.0.0. Higher versions may be used; however, compatibility and data correctness cannot be guaranteed and should be verified by the user.
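A minimal sketch (not from the original docs) of checking the locally installed versions against these minimums at runtime:

import pandas as pd
import pyarrow as pa

# Compare with the documented minimums: Pandas 0.23.2 and PyArrow 1.0.0.
print(pd.__version__, pa.__version__)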


Compatibility Setting for PyArrow >= 0.15.0 and Spark 2.3.x, 2.4.x

Since Arrow 0.15.0, a change in the binary IPC format requires an environment variable to be compatible with previous versions of Arrow <= 0.14.1. This is only necessary to do for PySpark users with versions 2.3.x and 2.4.x that have manually upgraded PyArrow to 0.15.0. The following can be added to conf/spark-env.sh to use the legacy Arrow IPC format:

ARROW_PRE_0_15_IPC_FORMAT=1

This will instruct PyArrow >= 0.15.0 to use the legacy IPC format with the older Arrow Java that is in Spark 2.3.x and 2.4.x. Not setting this environment variable will lead to a similar error as described in SPARK-29367 when running pandas_udfs or DataFrame.toPandas() with Arrow enabled. More information about the Arrow IPC change can be read on the Arrow 0.15.0 release blog.

2.2 Python Package Management

When you want to run your PySpark application on a cluster such as YARN, Kubernetes, Mesos, etc., you need to make sure that your code and all used libraries are available on the executors.

As an example, let's say you may want to run the Pandas UDF examples. As they use pyarrow as an underlying implementation, we need to make sure to have pyarrow installed on each executor on the cluster. Otherwise you may get errors such as ModuleNotFoundError: No module named 'pyarrow'.

Here is the script app.py from the previous example that will be executed on the cluster:

import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql import SparkSession

def main(spark):
    df = spark.createDataFrame(
        [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
        ("id", "v"))

    @pandas_udf("double")
    def mean_udf(v: pd.Series) -> float:
        return v.mean()

    print(df.groupby("id").agg(mean_udf(df['v'])).collect())

if __name__ == "__main__":
    main(SparkSession.builder.getOrCreate())

There are multiple ways to manage Python dependencies in the cluster:

• Using PySpark Native Features

• Using Conda

• Using Virtualenv

• Using PEX


2.2.1 Using PySpark Native Features

PySpark allows uploading Python files (.py), zipped Python packages (.zip), and Egg files (.egg) to the executors by:

• Setting the configuration setting spark.submit.pyFiles

• Using the --py-files option in Spark scripts

• Directly calling pyspark.SparkContext.addPyFile() in applications

This is a straightforward method to ship additional custom Python code to the cluster. You can just add individual files or zip whole packages and upload them. Using pyspark.SparkContext.addPyFile() allows you to upload code even after having started your job.
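For instance, a minimal sketch (the archive name my_package.zip and module my_package are hypothetical) of shipping extra code after the session has started:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Distribute a zipped package to every executor (and add it to the driver's sys.path).
spark.sparkContext.addPyFile("my_package.zip")

import my_package  # hypothetical module; now importable inside functions run on executors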

However, it does not allow adding packages built as Wheels and therefore does not allow including dependencies with native code.

2.2.2 Using Conda

Conda is one of the most widely-used Python package management systems. PySpark users can directly use a Conda environment to ship their third-party Python packages by leveraging conda-pack, which is a command line tool creating relocatable Conda environments.

The example below creates a Conda environment to use on both the driver and executor and packs it into an archive file. This archive file captures the Conda environment for Python and stores both the Python interpreter and all its relevant dependencies.

conda create -y -n pyspark_conda_env -c conda-forge pyarrow pandas conda-pack
conda activate pyspark_conda_env
conda pack -f -o pyspark_conda_env.tar.gz

After that, you can ship it together with scripts or in the code by using the --archives option or the spark.archives configuration (spark.yarn.dist.archives in Yarn). It automatically unpacks the archive on executors.

In the case of a spark-submit script, you can use it as follows:

export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./environment/bin/python
spark-submit --archives pyspark_conda_env.tar.gz#environment app.py

Note that PYSPARK_DRIVER_PYTHON above is not required for cluster modes in Yarn or Kubernetes.

If you’re on a regular Python shell or notebook, you can try it as shown below:

import os
from pyspark.sql import SparkSession
from app import main

os.environ['PYSPARK_PYTHON'] = "./environment/bin/python"
spark = SparkSession.builder.config(
    "spark.archives",  # 'spark.yarn.dist.archives' in Yarn.
    "pyspark_conda_env.tar.gz#environment").getOrCreate()

main(spark)

For a pyspark shell:


export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./environment/bin/python
pyspark --archives pyspark_conda_env.tar.gz#environment

2.2.3 Using Virtualenv

Virtualenv is a Python tool to create isolated Python environments. Since Python 3.3, a subset of its features has been integrated into Python as a standard library under the venv module. PySpark users can use virtualenv to manage Python dependencies in their clusters by using venv-pack in a similar way as conda-pack.

A virtual environment to use on both driver and executor can be created as demonstrated below. It packs the current virtual environment into an archive file, which contains both the Python interpreter and the dependencies.

python -m venv pyspark_venv
source pyspark_venv/bin/activate
pip install pyarrow pandas venv-pack
venv-pack -o pyspark_venv.tar.gz

You can directly pass/unpack the archive file and enable the environment on executors by leveraging the --archives option or the spark.archives configuration (spark.yarn.dist.archives in Yarn).

For spark-submit, you can use it by running the command as follows. Also, notice that PYSPARK_DRIVER_PYTHON is not necessary in Kubernetes or Yarn cluster modes.

export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./environment/bin/python
spark-submit --archives pyspark_venv.tar.gz#environment app.py

For regular Python shells or notebooks:

import os
from pyspark.sql import SparkSession
from app import main

os.environ['PYSPARK_PYTHON'] = "./environment/bin/python"
spark = SparkSession.builder.config(
    "spark.archives",  # 'spark.yarn.dist.archives' in Yarn.
    "pyspark_venv.tar.gz#environment").getOrCreate()

main(spark)

In the case of a pyspark shell:

export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./environment/bin/python
pyspark --archives pyspark_venv.tar.gz#environment


2.2.4 Using PEX

PySpark can also use PEX to ship the Python packages together. PEX is a tool that creates a self-contained Python environment. This is similar to Conda or virtualenv, but a .pex file is executable by itself.

The following example creates a .pex file for the driver and executor to use. The file contains the Python dependencies specified with the pex command.

pip install pyarrow pandas pex
pex pyspark pyarrow pandas -o pyspark_pex_env.pex

This file behaves similarly to a regular Python interpreter.

./pyspark_pex_env.pex -c "import pandas; print(pandas.__version__)"
1.1.5

However, a .pex file does not include a Python interpreter itself under the hood, so all nodes in a cluster should have the same Python interpreter installed.

In order to transfer and use the .pex file in a cluster, you should ship it via the spark.files configuration (spark.yarn.dist.files in Yarn) or the --files option because they are regular files instead of directories or archive files.

For application submission, you run the commands as shown below. Note that PYSPARK_DRIVER_PYTHON is not needed for cluster modes in Yarn or Kubernetes, and you may also need to set the PYSPARK_PYTHON environment variable on the AppMaster (--conf spark.yarn.appMasterEnv.PYSPARK_PYTHON=./myarchive.pex) in Yarn cluster mode.

export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./pyspark_pex_env.pex
spark-submit --files pyspark_pex_env.pex app.py

For regular Python shells or notebooks:

import os
from pyspark.sql import SparkSession
from app import main

os.environ['PYSPARK_PYTHON'] = "./pyspark_pex_env.pex"
spark = SparkSession.builder.config(
    "spark.files",  # 'spark.yarn.dist.files' in Yarn.
    "pyspark_pex_env.pex").getOrCreate()

main(spark)

For the interactive pyspark shell, the commands are almost the same:

export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./pyspark_pex_env.pex
pyspark --files pyspark_pex_env.pex

An end-to-end Docker example for deploying a standalone PySpark with SparkSession.builder and PEX can be found here - it uses cluster-pack, a library on top of PEX that automates the intermediate step of having to create and upload the PEX manually.


CHAPTER THREE

API REFERENCE

This page lists an overview of all public PySpark modules, classes, functions and methods.

3.1 Spark SQL

3.1.1 Core Classes

SparkSession(sparkContext[, jsparkSession])   The entry point to programming Spark with the Dataset and DataFrame API.
DataFrame(jdf, sql_ctx)                       A distributed collection of data grouped into named columns.
Column(jc)                                    A column in a DataFrame.
Row(*args, **kwargs)                          A row in DataFrame.
GroupedData(jgd, df)                          A set of methods for aggregations on a DataFrame, created by DataFrame.groupBy().
PandasCogroupedOps(gd1, gd2)                  A logical grouping of two GroupedData, created by GroupedData.cogroup().
DataFrameNaFunctions(df)                      Functionality for working with missing data in DataFrame.
DataFrameStatFunctions(df)                    Functionality for statistic functions with DataFrame.
Window()                                      Utility functions for defining window in DataFrames.

pyspark.sql.SparkSession

class pyspark.sql.SparkSession(sparkContext, jsparkSession=None)
    The entry point to programming Spark with the Dataset and DataFrame API.

A SparkSession can be used to create DataFrame, register DataFrame as tables, execute SQL over tables, cache tables, and read parquet files. To create a SparkSession, use the following builder pattern:

builder
    A class attribute having a Builder to construct SparkSession instances.


Examples

>>> spark = SparkSession.builder \
...     .master("local") \
...     .appName("Word Count") \
...     .config("spark.some.config.option", "some-value") \
...     .getOrCreate()

>>> from datetime import datetime
>>> from pyspark.sql import Row
>>> spark = SparkSession(sc)
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1,
...     b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
...     time=datetime(2014, 8, 1, 14, 1, 5))])
>>> df = allTypes.toDF()
>>> df.createOrReplaceTempView("allTypes")
>>> spark.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
...     'from allTypes where b and i > 0').collect()
[Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
>>> df.rdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
[(1, 'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]

Methods

createDataFrame(data[, schema, ...])        Creates a DataFrame from an RDD, a list or a pandas.DataFrame.
getActiveSession()                          Returns the active SparkSession for the current thread, returned by the builder.
newSession()                                Returns a new SparkSession as new session, that has separate SQLConf, registered temporary views and UDFs, but shared SparkContext and table cache.
range(start[, end, step, numPartitions])    Create a DataFrame with single pyspark.sql.types.LongType column named id, containing elements in a range from start to end (exclusive) with step value step.
sql(sqlQuery)                               Returns a DataFrame representing the result of the given query.
stop()                                      Stop the underlying SparkContext.
table(tableName)                            Returns the specified table as a DataFrame.


Attributes

builder  A class attribute having a Builder to construct SparkSession instances.
catalog  Interface through which the user may create, drop, alter or query underlying databases, tables, functions, etc.
conf  Runtime configuration interface for Spark.
read  Returns a DataFrameReader that can be used to read data in as a DataFrame.
readStream  Returns a DataStreamReader that can be used to read data streams as a streaming DataFrame.
sparkContext  Returns the underlying SparkContext.
streams  Returns a StreamingQueryManager that allows managing all the StreamingQuery instances active on this context.
udf  Returns a UDFRegistration for UDF registration.
version  The version of Spark on which this application is running.

pyspark.sql.DataFrame

class pyspark.sql.DataFrame(jdf, sql_ctx)
A distributed collection of data grouped into named columns.

A DataFrame is equivalent to a relational table in Spark SQL, and can be created using various functions in SparkSession:

people = spark.read.parquet("...")

Once created, it can be manipulated using the various domain-specific-language (DSL) functions defined in: DataFrame, Column.

To select a column from the DataFrame, use the apply method:

ageCol = people.age

A more concrete example:

# To create DataFrame using SparkSession
people = spark.read.parquet("...")
department = spark.read.parquet("...")

people.filter(people.age > 30).join(department, people.deptId == department.id) \
    .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})

New in version 1.3.0.


Methods

agg(*exprs)  Aggregate on the entire DataFrame without groups (shorthand for df.groupBy().agg()).
alias(alias)  Returns a new DataFrame with an alias set.
approxQuantile(col, probabilities, relativeError)  Calculates the approximate quantiles of numerical columns of a DataFrame.
cache()  Persists the DataFrame with the default storage level (MEMORY_AND_DISK).
checkpoint([eager])  Returns a checkpointed version of this Dataset.
coalesce(numPartitions)  Returns a new DataFrame that has exactly numPartitions partitions.
colRegex(colName)  Selects column based on the column name specified as a regex and returns it as Column.
collect()  Returns all the records as a list of Row.
corr(col1, col2[, method])  Calculates the correlation of two columns of a DataFrame as a double value.
count()  Returns the number of rows in this DataFrame.
cov(col1, col2)  Calculate the sample covariance for the given columns, specified by their names, as a double value.
createGlobalTempView(name)  Creates a global temporary view with this DataFrame.
createOrReplaceGlobalTempView(name)  Creates or replaces a global temporary view using the given name.
createOrReplaceTempView(name)  Creates or replaces a local temporary view with this DataFrame.
createTempView(name)  Creates a local temporary view with this DataFrame.
crossJoin(other)  Returns the cartesian product with another DataFrame.
crosstab(col1, col2)  Computes a pair-wise frequency table of the given columns.
cube(*cols)  Create a multi-dimensional cube for the current DataFrame using the specified columns, so we can run aggregations on them.
describe(*cols)  Computes basic statistics for numeric and string columns.
distinct()  Returns a new DataFrame containing the distinct rows in this DataFrame.
drop(*cols)  Returns a new DataFrame that drops the specified column.
dropDuplicates([subset])  Return a new DataFrame with duplicate rows removed, optionally only considering certain columns.
drop_duplicates([subset])  drop_duplicates() is an alias for dropDuplicates().
dropna([how, thresh, subset])  Returns a new DataFrame omitting rows with null values.
exceptAll(other)  Return a new DataFrame containing rows in this DataFrame but not in another DataFrame while preserving duplicates.


explain([extended, mode])  Prints the (logical and physical) plans to the console for debugging purpose.
fillna(value[, subset])  Replace null values, alias for na.fill().
filter(condition)  Filters rows using the given condition.
first()  Returns the first row as a Row.
foreach(f)  Applies the f function to all Row of this DataFrame.
foreachPartition(f)  Applies the f function to each partition of this DataFrame.
freqItems(cols[, support])  Finding frequent items for columns, possibly with false positives.
groupBy(*cols)  Groups the DataFrame using the specified columns, so we can run aggregation on them.
groupby(*cols)  groupby() is an alias for groupBy().
head([n])  Returns the first n rows.
hint(name, *parameters)  Specifies some hint on the current DataFrame.
inputFiles()  Returns a best-effort snapshot of the files that compose this DataFrame.
intersect(other)  Return a new DataFrame containing rows only in both this DataFrame and another DataFrame.
intersectAll(other)  Return a new DataFrame containing rows in both this DataFrame and another DataFrame while preserving duplicates.
isLocal()  Returns True if the collect() and take() methods can be run locally (without any Spark executors).
join(other[, on, how])  Joins with another DataFrame, using the given join expression.
limit(num)  Limits the result count to the number specified.
localCheckpoint([eager])  Returns a locally checkpointed version of this Dataset.
mapInPandas(func, schema)  Maps an iterator of batches in the current DataFrame using a Python native function that takes and outputs a pandas DataFrame, and returns the result as a DataFrame.
orderBy(*cols, **kwargs)  Returns a new DataFrame sorted by the specified column(s).
persist([storageLevel])  Sets the storage level to persist the contents of the DataFrame across operations after the first time it is computed.
printSchema()  Prints out the schema in the tree format.
randomSplit(weights[, seed])  Randomly splits this DataFrame with the provided weights.
registerTempTable(name)  Registers this DataFrame as a temporary table using the given name.
repartition(numPartitions, *cols)  Returns a new DataFrame partitioned by the given partitioning expressions.
repartitionByRange(numPartitions, *cols)  Returns a new DataFrame partitioned by the given partitioning expressions.
replace(to_replace[, value, subset])  Returns a new DataFrame replacing a value with another value.


rollup(*cols)  Create a multi-dimensional rollup for the current DataFrame using the specified columns, so we can run aggregation on them.
sameSemantics(other)  Returns True when the logical query plans inside both DataFrames are equal and therefore return same results.
sample([withReplacement, fraction, seed])  Returns a sampled subset of this DataFrame.
sampleBy(col, fractions[, seed])  Returns a stratified sample without replacement based on the fraction given on each stratum.
select(*cols)  Projects a set of expressions and returns a new DataFrame.
selectExpr(*expr)  Projects a set of SQL expressions and returns a new DataFrame.
semanticHash()  Returns a hash code of the logical query plan against this DataFrame.
show([n, truncate, vertical])  Prints the first n rows to the console.
sort(*cols, **kwargs)  Returns a new DataFrame sorted by the specified column(s).
sortWithinPartitions(*cols, **kwargs)  Returns a new DataFrame with each partition sorted by the specified column(s).
subtract(other)  Return a new DataFrame containing rows in this DataFrame but not in another DataFrame.
summary(*statistics)  Computes specified statistics for numeric and string columns.
tail(num)  Returns the last num rows as a list of Row.
take(num)  Returns the first num rows as a list of Row.
toDF(*cols)  Returns a new DataFrame with new specified column names.
toJSON([use_unicode])  Converts a DataFrame into a RDD of string.
toLocalIterator([prefetchPartitions])  Returns an iterator that contains all of the rows in this DataFrame.
toPandas()  Returns the contents of this DataFrame as Pandas pandas.DataFrame.
transform(func)  Returns a new DataFrame.
union(other)  Return a new DataFrame containing union of rows in this and another DataFrame.
unionAll(other)  Return a new DataFrame containing union of rows in this and another DataFrame.
unionByName(other[, allowMissingColumns])  Returns a new DataFrame containing union of rows in this and another DataFrame.
unpersist([blocking])  Marks the DataFrame as non-persistent, and remove all blocks for it from memory and disk.
where(condition)  where() is an alias for filter().
withColumn(colName, col)  Returns a new DataFrame by adding a column or replacing the existing column that has the same name.
withColumnRenamed(existing, new)  Returns a new DataFrame by renaming an existing column.
withWatermark(eventTime, delayThreshold)  Defines an event time watermark for this DataFrame.
writeTo(table)  Create a write configuration builder for v2 sources.


Attributes

columns  Returns all column names as a list.
dtypes  Returns all column names and their data types as a list.
isStreaming  Returns True if this Dataset contains one or more sources that continuously return data as it arrives.
na  Returns a DataFrameNaFunctions for handling missing values.
rdd  Returns the content as an pyspark.RDD of Row.
schema  Returns the schema of this DataFrame as a pyspark.sql.types.StructType.
stat  Returns a DataFrameStatFunctions for statistic functions.
storageLevel  Get the DataFrame's current storage level.
write  Interface for saving the content of the non-streaming DataFrame out into external storage.
writeStream  Interface for saving the content of the streaming DataFrame out into external storage.

pyspark.sql.Column

class pyspark.sql.Column(jc)
A column in a DataFrame.

Column instances can be created by:

# 1. Select a column out of a DataFrame
df.colName
df["colName"]

# 2. Create from an expression
df.colName + 1
1 / df.colName

New in version 1.3.0.

Methods

alias(*alias, **kwargs)  Returns this column aliased with a new name or names (in the case of expressions that return more than one column, such as explode).
asc()  Returns a sort expression based on ascending order of the column.
asc_nulls_first()  Returns a sort expression based on ascending order of the column, and null values return before non-null values.
asc_nulls_last()  Returns a sort expression based on ascending order of the column, and null values appear after non-null values.


astype(dataType)  astype() is an alias for cast().
between(lowerBound, upperBound)  A boolean expression that is evaluated to true if the value of this expression is between the given columns.
bitwiseAND(other)  Compute bitwise AND of this expression with another expression.
bitwiseOR(other)  Compute bitwise OR of this expression with another expression.
bitwiseXOR(other)  Compute bitwise XOR of this expression with another expression.
cast(dataType)  Convert the column into type dataType.
contains(other)  Contains the other element.
desc()  Returns a sort expression based on the descending order of the column.
desc_nulls_first()  Returns a sort expression based on the descending order of the column, and null values appear before non-null values.
desc_nulls_last()  Returns a sort expression based on the descending order of the column, and null values appear after non-null values.
dropFields(*fieldNames)  An expression that drops fields in StructType by name.
endswith(other)  String ends with.
eqNullSafe(other)  Equality test that is safe for null values.
getField(name)  An expression that gets a field by name in a StructField.
getItem(key)  An expression that gets an item at position ordinal out of a list, or gets an item by key out of a dict.
isNotNull()  True if the current expression is NOT null.
isNull()  True if the current expression is null.
isin(*cols)  A boolean expression that is evaluated to true if the value of this expression is contained by the evaluated values of the arguments.
like(other)  SQL like expression.
name(*alias, **kwargs)  name() is an alias for alias().
otherwise(value)  Evaluates a list of conditions and returns one of multiple possible result expressions.
over(window)  Define a windowing column.
rlike(other)  SQL RLIKE expression (LIKE with Regex).
startswith(other)  String starts with.
substr(startPos, length)  Return a Column which is a substring of the column.
when(condition, value)  Evaluates a list of conditions and returns one of multiple possible result expressions.
withField(fieldName, col)  An expression that adds/replaces a field in StructType by name.
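To see how a few of these expressions combine in practice, the following is a minimal sketch; the DataFrame, its column names, and the age thresholds are illustrative assumptions rather than part of the reference entries above.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local").appName("column-sketch").getOrCreate()

# Hypothetical data, used only to exercise a few Column expressions.
df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ["name", "age"])

df.select(
    df.name,
    (df.age + 1).alias("next_age"),                # arithmetic expression
    df.age.between(2, 4).alias("toddler"),         # boolean expression
    F.when(df.age > 3, "old").otherwise("young").alias("bucket"),
).show()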


pyspark.sql.Row

class pyspark.sql.Row(*args, **kwargs)
A row in DataFrame. The fields in it can be accessed:

• like attributes (row.key)

• like dictionary values (row[key])

key in row will search through row keys.

Row can be used to create a row object by using named arguments. It is not allowed to omit a named argument to represent that the value is None or missing. This should be explicitly set to None in this case.

Changed in version 3.0.0: Rows created from named arguments no longer have field names sorted alphabetically and will be ordered in the position as entered.

Examples

>>> row = Row(name="Alice", age=11)
>>> row
Row(name='Alice', age=11)
>>> row['name'], row['age']
('Alice', 11)
>>> row.name, row.age
('Alice', 11)
>>> 'name' in row
True
>>> 'wrong_key' in row
False

Row can also be used to create another Row-like class, which can then be used to create Row objects, such as

>>> Person = Row("name", "age")
>>> Person
<Row('name', 'age')>
>>> 'name' in Person
True
>>> 'wrong_key' in Person
False
>>> Person("Alice", 11)
Row(name='Alice', age=11)

This form can also be used to create rows as tuple values, i.e. with unnamed fields.

>>> row1 = Row("Alice", 11)
>>> row2 = Row(name="Alice", age=11)
>>> row1 == row2
True


Methods

asDict([recursive])  Return as a dict
count(value, /)  Return number of occurrences of value.
index(value[, start, stop])  Return first index of value.

pyspark.sql.GroupedData

class pyspark.sql.GroupedData(jgd, df)
A set of methods for aggregations on a DataFrame, created by DataFrame.groupBy().

New in version 1.3.

Methods

agg(*exprs)  Compute aggregates and returns the result as a DataFrame.
apply(udf)  It is an alias of pyspark.sql.GroupedData.applyInPandas(); however, it takes a pyspark.sql.functions.pandas_udf() whereas pyspark.sql.GroupedData.applyInPandas() takes a Python native function.
applyInPandas(func, schema)  Maps each group of the current DataFrame using a pandas udf and returns the result as a DataFrame.
avg(*cols)  Computes average values for each numeric columns for each group.
cogroup(other)  Cogroups this group with another group so that we can run cogrouped operations.
count()  Counts the number of records for each group.
max(*cols)  Computes the max value for each numeric columns for each group.
mean(*cols)  Computes average values for each numeric columns for each group.
min(*cols)  Computes the min value for each numeric column for each group.
pivot(pivot_col[, values])  Pivots a column of the current DataFrame and perform the specified aggregation.
sum(*cols)  Compute the sum for each numeric columns for each group.
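As a quick illustration of how a GroupedData object is produced and consumed, here is a minimal sketch; the SparkSession setup and the toy department data are assumptions for illustration only.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local").appName("grouped-data-sketch").getOrCreate()

# Hypothetical data, not taken from the reference entries above.
df = spark.createDataFrame(
    [("sales", 3000), ("sales", 4600), ("hr", 4100)],
    ["dept", "salary"],
)

# groupBy() returns a GroupedData; agg()/avg()/count() turn it back into a DataFrame.
grouped = df.groupBy("dept")
grouped.agg(F.avg("salary").alias("avg_salary"), F.count("*").alias("n")).show()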


pyspark.sql.PandasCogroupedOps

class pyspark.sql.PandasCogroupedOps(gd1, gd2)
A logical grouping of two GroupedData, created by GroupedData.cogroup().

New in version 3.0.0.

Notes

This API is experimental.

Methods

applyInPandas(func, schema)  Applies a function to each cogroup using pandas and returns the result as a DataFrame.
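For orientation, a minimal sketch of the cogrouped applyInPandas pattern follows; the two input DataFrames, the key column, and the merge logic are illustrative assumptions, not part of the reference entry.

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").appName("cogroup-sketch").getOrCreate()

# Hypothetical inputs; any two DataFrames sharing a grouping key would do.
df1 = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "left"])
df2 = spark.createDataFrame([(1, "x"), (2, "y")], ["id", "right"])

def merge(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    # Each cogroup arrives as two pandas DataFrames holding the rows for one key.
    return pd.merge(left, right, on="id")

result = (
    df1.groupBy("id")
    .cogroup(df2.groupBy("id"))
    .applyInPandas(merge, schema="id long, left string, right string")
)
result.show()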

pyspark.sql.DataFrameNaFunctions

class pyspark.sql.DataFrameNaFunctions(df)
Functionality for working with missing data in DataFrame.

New in version 1.4.

Methods

drop([how, thresh, subset])  Returns a new DataFrame omitting rows with null values.
fill(value[, subset])  Replace null values, alias for na.fill().
replace(to_replace[, value, subset])  Returns a new DataFrame replacing a value with another value.
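These methods are usually reached through the na attribute of a DataFrame. A minimal sketch follows; the example data with missing values is an assumption for illustration.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").appName("na-sketch").getOrCreate()

# Hypothetical data containing nulls.
df = spark.createDataFrame([(1, None), (2, "b"), (3, "c")], ["id", "label"])

df.na.drop(how="any").show()               # drop rows containing any null
df.na.fill({"label": "unknown"}).show()    # fill nulls per column
df.na.replace("b", "B", subset=["label"]).show()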

pyspark.sql.DataFrameStatFunctions

class pyspark.sql.DataFrameStatFunctions(df)
Functionality for statistic functions with DataFrame.

New in version 1.4.

Methods

approxQuantile(col, probabilities, relativeError)  Calculates the approximate quantiles of numerical columns of a DataFrame.
corr(col1, col2[, method])  Calculates the correlation of two columns of a DataFrame as a double value.
cov(col1, col2)  Calculate the sample covariance for the given columns, specified by their names, as a double value.


crosstab(col1, col2)  Computes a pair-wise frequency table of the given columns.
freqItems(cols[, support])  Finding frequent items for columns, possibly with false positives.
sampleBy(col, fractions[, seed])  Returns a stratified sample without replacement based on the fraction given on each stratum.
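These methods are usually reached through the stat attribute of a DataFrame. A minimal sketch follows; the numeric sample data is an assumption for illustration.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").appName("stat-sketch").getOrCreate()

# Hypothetical numeric data.
df = spark.createDataFrame([(1, 10.0), (2, 20.0), (3, 30.0)], ["a", "b"])

print(df.stat.corr("a", "b"))                     # Pearson correlation of two columns
print(df.stat.approxQuantile("b", [0.5], 0.0))    # approximate median of column b
df.stat.crosstab("a", "b").show()                 # pair-wise frequency table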

pyspark.sql.Window

class pyspark.sql.Window
Utility functions for defining window in DataFrames.

New in version 1.4.

Notes

When ordering is not defined, an unbounded window frame (rowFrame, unboundedPreceding, unboundedFollowing) is used by default. When ordering is defined, a growing window frame (rangeFrame, unboundedPreceding, currentRow) is used by default.

Examples

>>> # ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
>>> window = Window.orderBy("date").rowsBetween(Window.unboundedPreceding, Window.currentRow)

>>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING
>>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3)
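To show how such a WindowSpec is consumed once defined, here is a minimal sketch using Column.over(); the DataFrame and its column names are assumptions for illustration.

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local").appName("window-sketch").getOrCreate()

# Hypothetical data: one row per (country, date, amount).
df = spark.createDataFrame(
    [("US", "2020-01-01", 10), ("US", "2020-01-02", 20), ("FR", "2020-01-01", 5)],
    ["country", "date", "amount"],
)

# Running total per country, ordered by date.
w = (Window.partitionBy("country").orderBy("date")
     .rowsBetween(Window.unboundedPreceding, Window.currentRow))
df.withColumn("running_total", F.sum("amount").over(w)).show()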

Methods

orderBy(*cols)  Creates a WindowSpec with the ordering defined.
partitionBy(*cols)  Creates a WindowSpec with the partitioning defined.
rangeBetween(start, end)  Creates a WindowSpec with the frame boundaries defined, from start (inclusive) to end (inclusive).
rowsBetween(start, end)  Creates a WindowSpec with the frame boundaries defined, from start (inclusive) to end (inclusive).


Attributes

currentRow
unboundedFollowing
unboundedPreceding

3.1.2 Spark Session APIs

The entry point to programming Spark with the Dataset and DataFrame API. To create a Spark session, you should use the SparkSession.builder attribute. See also SparkSession.

SparkSession.builder.appName(name)  Sets a name for the application, which will be shown in the Spark web UI.
SparkSession.builder.config([key, value, conf])  Sets a config option.
SparkSession.builder.enableHiveSupport()  Enables Hive support, including connectivity to a persistent Hive metastore, support for Hive SerDes, and Hive user-defined functions.
SparkSession.builder.getOrCreate()  Gets an existing SparkSession or, if there is no existing one, creates a new one based on the options set in this builder.
SparkSession.builder.master(master)  Sets the Spark master URL to connect to, such as "local" to run locally, "local[4]" to run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone cluster.
SparkSession.catalog  Interface through which the user may create, drop, alter or query underlying databases, tables, functions, etc.
SparkSession.conf  Runtime configuration interface for Spark.
SparkSession.createDataFrame(data[, schema, ...])  Creates a DataFrame from an RDD, a list or a pandas.DataFrame.
SparkSession.getActiveSession()  Returns the active SparkSession for the current thread, returned by the builder
SparkSession.newSession()  Returns a new SparkSession as new session, that has separate SQLConf, registered temporary views and UDFs, but shared SparkContext and table cache.
SparkSession.range(start[, end, step, ...])  Create a DataFrame with single pyspark.sql.types.LongType column named id, containing elements in a range from start to end (exclusive) with step value step.
SparkSession.read  Returns a DataFrameReader that can be used to read data in as a DataFrame.
SparkSession.readStream  Returns a DataStreamReader that can be used to read data streams as a streaming DataFrame.
SparkSession.sparkContext  Returns the underlying SparkContext.
SparkSession.sql(sqlQuery)  Returns a DataFrame representing the result of the given query.
SparkSession.stop()  Stop the underlying SparkContext.
SparkSession.streams  Returns a StreamingQueryManager that allows managing all the StreamingQuery instances active on this context.


SparkSession.table(tableName)  Returns the specified table as a DataFrame.
SparkSession.udf  Returns a UDFRegistration for UDF registration.
SparkSession.version  The version of Spark on which this application is running.

pyspark.sql.SparkSession.builder.appName

SparkSession.builder.appName(name)
Sets a name for the application, which will be shown in the Spark web UI.

If no application name is set, a randomly generated name will be used.

New in version 2.0.0.

Parameters

name [str] an application name

pyspark.sql.SparkSession.builder.config

SparkSession.builder.config(key=None, value=None, conf=None)
Sets a config option. Options set using this method are automatically propagated to both SparkConf and SparkSession's own configuration.

New in version 2.0.0.

Parameters

key [str, optional] a key name string for configuration property

value [str, optional] a value for configuration property

conf [SparkConf, optional] an instance of SparkConf

Examples

For an existing SparkConf, use conf parameter.

>>> from pyspark.conf import SparkConf
>>> SparkSession.builder.config(conf=SparkConf())
<pyspark.sql.session...

For a (key, value) pair, you can omit parameter names.

>>> SparkSession.builder.config("spark.some.config.option", "some-value")
<pyspark.sql.session...


pyspark.sql.SparkSession.builder.enableHiveSupport

SparkSession.builder.enableHiveSupport()
Enables Hive support, including connectivity to a persistent Hive metastore, support for Hive SerDes, and Hive user-defined functions.

New in version 2.0.

pyspark.sql.SparkSession.builder.getOrCreate

SparkSession.builder.getOrCreate()
Gets an existing SparkSession or, if there is no existing one, creates a new one based on the options set in this builder.

New in version 2.0.0.

Examples

This method first checks whether there is a valid global default SparkSession, and if yes, returns that one. If no valid global default SparkSession exists, the method creates a new SparkSession and assigns the newly created SparkSession as the global default.

>>> s1 = SparkSession.builder.config("k1", "v1").getOrCreate()
>>> s1.conf.get("k1") == "v1"
True

In case an existing SparkSession is returned, the config options specified in this builder will be applied to the existing SparkSession.

>>> s2 = SparkSession.builder.config("k2", "v2").getOrCreate()
>>> s1.conf.get("k1") == s2.conf.get("k1")
True
>>> s1.conf.get("k2") == s2.conf.get("k2")
True

pyspark.sql.SparkSession.builder.master

SparkSession.builder.master(master)
Sets the Spark master URL to connect to, such as "local" to run locally, "local[4]" to run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone cluster.

New in version 2.0.0.

Parameters

master [str] a url for spark master


pyspark.sql.SparkSession.catalog

property SparkSession.catalog
Interface through which the user may create, drop, alter or query underlying databases, tables, functions, etc.

New in version 2.0.0.

Returns

Catalog

pyspark.sql.SparkSession.conf

property SparkSession.conf
Runtime configuration interface for Spark.

This is the interface through which the user can get and set all Spark and Hadoop configurations that are relevant to Spark SQL. When getting the value of a config, this defaults to the value set in the underlying SparkContext, if any.

New in version 2.0.

pyspark.sql.SparkSession.createDataFrame

SparkSession.createDataFrame(data, schema=None, samplingRatio=None, verifySchema=True)
Creates a DataFrame from an RDD, a list or a pandas.DataFrame.

When schema is a list of column names, the type of each column will be inferred from data.

When schema is None, it will try to infer the schema (column names and types) from data, which should be an RDD of either Row, namedtuple, or dict.

When schema is pyspark.sql.types.DataType or a datatype string, it must match the real data, or an exception will be thrown at runtime. If the given schema is not pyspark.sql.types.StructType, it will be wrapped into a pyspark.sql.types.StructType as its only field, and the field name will be "value". Each record will also be wrapped into a tuple, which can be converted to row later.

If schema inference is needed, samplingRatio is used to determine the ratio of rows used for schema inference. The first row will be used if samplingRatio is None.

New in version 2.0.0.

Changed in version 2.1.0: Added verifySchema.

Parameters

data [RDD or iterable] an RDD of any kind of SQL data representation (e.g. Row, tuple, int, boolean, etc.), or list, or pandas.DataFrame.

schema [pyspark.sql.types.DataType, str or list, optional] a pyspark.sql.types.DataType or a datatype string or a list of column names, default is None. The data type string format equals to pyspark.sql.types.DataType.simpleString, except that top level struct type can omit the struct<> and atomic types use typeName() as their format, e.g. use byte instead of tinyint for pyspark.sql.types.ByteType. We can also use int as a short name for pyspark.sql.types.IntegerType.

samplingRatio [float, optional] the sample ratio of rows used for inferring

verifySchema [bool, optional] verify data types of every row against schema. Enabled by default.


Returns

DataFrame

Notes

Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental.

Examples

>>> l = [('Alice', 1)]
>>> spark.createDataFrame(l).collect()
[Row(_1='Alice', _2=1)]
>>> spark.createDataFrame(l, ['name', 'age']).collect()
[Row(name='Alice', age=1)]

>>> d = [{'name': 'Alice', 'age': 1}]
>>> spark.createDataFrame(d).collect()
[Row(age=1, name='Alice')]

>>> rdd = sc.parallelize(l)
>>> spark.createDataFrame(rdd).collect()
[Row(_1='Alice', _2=1)]
>>> df = spark.createDataFrame(rdd, ['name', 'age'])
>>> df.collect()
[Row(name='Alice', age=1)]

>>> from pyspark.sql import Row
>>> Person = Row('name', 'age')
>>> person = rdd.map(lambda r: Person(*r))
>>> df2 = spark.createDataFrame(person)
>>> df2.collect()
[Row(name='Alice', age=1)]

>>> from pyspark.sql.types import *
>>> schema = StructType([
...     StructField("name", StringType(), True),
...     StructField("age", IntegerType(), True)])
>>> df3 = spark.createDataFrame(rdd, schema)
>>> df3.collect()
[Row(name='Alice', age=1)]

>>> spark.createDataFrame(df.toPandas()).collect()
[Row(name='Alice', age=1)]
>>> spark.createDataFrame(pandas.DataFrame([[1, 2]])).collect()
[Row(0=1, 1=2)]

>>> spark.createDataFrame(rdd, "a: string, b: int").collect()
[Row(a='Alice', b=1)]
>>> rdd = rdd.map(lambda row: row[1])
>>> spark.createDataFrame(rdd, "int").collect()
[Row(value=1)]
>>> spark.createDataFrame(rdd, "boolean").collect()
Traceback (most recent call last):
    ...
Py4JJavaError: ...

pyspark.sql.SparkSession.getActiveSession

classmethod SparkSession.getActiveSession()
Returns the active SparkSession for the current thread, returned by the builder

New in version 3.0.0.

Returns

SparkSession Spark session if an active session exists for the current thread

Examples

>>> s = SparkSession.getActiveSession()
>>> l = [('Alice', 1)]
>>> rdd = s.sparkContext.parallelize(l)
>>> df = s.createDataFrame(rdd, ['name', 'age'])
>>> df.select("age").collect()
[Row(age=1)]

pyspark.sql.SparkSession.newSession

SparkSession.newSession()
Returns a new SparkSession as new session, that has separate SQLConf, registered temporary views and UDFs, but shared SparkContext and table cache.

New in version 2.0.

pyspark.sql.SparkSession.range

SparkSession.range(start, end=None, step=1, numPartitions=None)
Create a DataFrame with single pyspark.sql.types.LongType column named id, containing elements in a range from start to end (exclusive) with step value step.

New in version 2.0.0.

Parameters

start [int] the start value

end [int, optional] the end value (exclusive)

step [int, optional] the incremental step (default: 1)

numPartitions [int, optional] the number of partitions of the DataFrame

Returns

DataFrame


Examples

>>> spark.range(1, 7, 2).collect()
[Row(id=1), Row(id=3), Row(id=5)]

If only one argument is specified, it will be used as the end value.

>>> spark.range(3).collect()
[Row(id=0), Row(id=1), Row(id=2)]

pyspark.sql.SparkSession.read

property SparkSession.read
Returns a DataFrameReader that can be used to read data in as a DataFrame.

New in version 2.0.0.

Returns

DataFrameReader

pyspark.sql.SparkSession.readStream

property SparkSession.readStream
Returns a DataStreamReader that can be used to read data streams as a streaming DataFrame.

New in version 2.0.0.

Returns

DataStreamReader

Notes

This API is evolving.

pyspark.sql.SparkSession.sparkContext

property SparkSession.sparkContext
Returns the underlying SparkContext.

New in version 2.0.

pyspark.sql.SparkSession.sql

SparkSession.sql(sqlQuery)
Returns a DataFrame representing the result of the given query.

New in version 2.0.0.

Returns

DataFrame


Examples

>>> df.createOrReplaceTempView("table1")
>>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
[Row(f1=1, f2='row1'), Row(f1=2, f2='row2'), Row(f1=3, f2='row3')]

pyspark.sql.SparkSession.stop

SparkSession.stop()
Stop the underlying SparkContext.

New in version 2.0.

pyspark.sql.SparkSession.streams

property SparkSession.streams
Returns a StreamingQueryManager that allows managing all the StreamingQuery instances active on this context.

New in version 2.0.0.

Returns

StreamingQueryManager

Notes

This API is evolving.

pyspark.sql.SparkSession.table

SparkSession.table(tableName)
Returns the specified table as a DataFrame.

New in version 2.0.0.

Returns

DataFrame

Examples

>>> df.createOrReplaceTempView("table1")
>>> df2 = spark.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
True


pyspark.sql.SparkSession.udf

property SparkSession.udf
Returns a UDFRegistration for UDF registration.

New in version 2.0.0.

Returns

UDFRegistration
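As a minimal sketch of what this registration interface is used for, consider the following; the UDF name and logic are assumptions for illustration, and registering a function simply makes it callable from SQL text.

from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local").appName("udf-sketch").getOrCreate()

# Register a hypothetical Python function under a SQL-visible name.
spark.udf.register("plus_one", lambda x: x + 1, IntegerType())

spark.range(3).createOrReplaceTempView("nums")
spark.sql("SELECT id, plus_one(id) AS next FROM nums").show()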

pyspark.sql.SparkSession.version

property SparkSession.version
The version of Spark on which this application is running.

New in version 2.0.

3.1.3 Input and Output

DataFrameReader.csv(path[, schema, sep, ...])  Loads a CSV file and returns the result as a DataFrame.
DataFrameReader.format(source)  Specifies the input data source format.
DataFrameReader.jdbc(url, table[, column, ...])  Construct a DataFrame representing the database table named table accessible via JDBC URL url and connection properties.
DataFrameReader.json(path[, schema, ...])  Loads JSON files and returns the results as a DataFrame.
DataFrameReader.load([path, format, schema])  Loads data from a data source and returns it as a DataFrame.
DataFrameReader.option(key, value)  Adds an input option for the underlying data source.
DataFrameReader.options(**options)  Adds input options for the underlying data source.
DataFrameReader.orc(path[, mergeSchema, ...])  Loads ORC files, returning the result as a DataFrame.
DataFrameReader.parquet(*paths, **options)  Loads Parquet files, returning the result as a DataFrame.
DataFrameReader.schema(schema)  Specifies the input schema.
DataFrameReader.table(tableName)  Returns the specified table as a DataFrame.
DataFrameWriter.bucketBy(numBuckets, col, *cols)  Buckets the output by the given columns. If specified, the output is laid out on the file system similar to Hive's bucketing scheme.
DataFrameWriter.csv(path[, mode, ...])  Saves the content of the DataFrame in CSV format at the specified path.
DataFrameWriter.format(source)  Specifies the underlying output data source.
DataFrameWriter.insertInto(tableName[, ...])  Inserts the content of the DataFrame to the specified table.
DataFrameWriter.jdbc(url, table[, mode, ...])  Saves the content of the DataFrame to an external database table via JDBC.
DataFrameWriter.json(path[, mode, ...])  Saves the content of the DataFrame in JSON format (JSON Lines text format or newline-delimited JSON) at the specified path.
DataFrameWriter.mode(saveMode)  Specifies the behavior when data or table already exists.
DataFrameWriter.option(key, value)  Adds an output option for the underlying data source.


DataFrameWriter.options(**options)  Adds output options for the underlying data source.
DataFrameWriter.orc(path[, mode, ...])  Saves the content of the DataFrame in ORC format at the specified path.
DataFrameWriter.parquet(path[, mode, ...])  Saves the content of the DataFrame in Parquet format at the specified path.
DataFrameWriter.partitionBy(*cols)  Partitions the output by the given columns on the file system.
DataFrameWriter.save([path, format, mode, ...])  Saves the contents of the DataFrame to a data source.
DataFrameWriter.saveAsTable(name[, format, ...])  Saves the content of the DataFrame as the specified table.
DataFrameWriter.sortBy(col, *cols)  Sorts the output in each bucket by the given columns on the file system.
DataFrameWriter.text(path[, compression, ...])  Saves the content of the DataFrame in a text file at the specified path.
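To connect the reader and writer halves of this table, here is a minimal round-trip sketch; the output path and the toy data are assumptions for illustration.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").appName("io-sketch").getOrCreate()

df = spark.createDataFrame([("Alice", 1), ("Bob", 2)], ["name", "age"])

# DataFrameWriter: overwrite any previous output at this hypothetical path.
df.write.mode("overwrite").parquet("/tmp/people.parquet")

# DataFrameReader: load it back as a DataFrame.
df2 = spark.read.parquet("/tmp/people.parquet")
df2.show()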

pyspark.sql.DataFrameReader.csv

DataFrameReader.csv(path, schema=None, sep=None, encoding=None, quote=None, escape=None, comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None, pathGlobFilter=None, recursiveFileLookup=None, modifiedBefore=None, modifiedAfter=None, unescapedQuoteHandling=None)

Loads a CSV file and returns the result as a DataFrame.

This function will go through the input once to determine the input schema if inferSchema is enabled. To avoid going through the entire data once, disable the inferSchema option or specify the schema explicitly using schema.

New in version 2.0.0.

Parameters

path [str or list] string, or list of strings, for input path(s), or RDD of Strings storing CSV rows.

schema [pyspark.sql.types.StructType or str, optional] an optional pyspark.sql.types.StructType for the input schema or a DDL-formatted string (For example col0 INT, col1 DOUBLE).

sep [str, optional] sets a separator (one or more characters) for each field and value. If None is set, it uses the default value, ,.

encoding [str, optional] decodes the CSV files by the given encoding type. If None is set, it uses the default value, UTF-8.

quote [str, optional] sets a single character used for escaping quoted values where the separator can be part of the value. If None is set, it uses the default value, ". If you would like to turn off quotations, you need to set an empty string.

escape [str, optional] sets a single character used for escaping quotes inside an already quoted value. If None is set, it uses the default value, \.


comment [str, optional] sets a single character used for skipping lines beginning with this character. By default (None), it is disabled.

header [str or bool, optional] uses the first line as names of columns. If None is set, it uses the default value, false.

Note: if the given path is a RDD of Strings, this header option will remove all lines same with the header if exists.

inferSchema [str or bool, optional] infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, false.

enforceSchema [str or bool, optional] If it is set to true, the specified or inferred schema will be forcibly applied to datasource files, and headers in CSV files will be ignored. If the option is set to false, the schema will be validated against all headers in CSV files or the first header in RDD if the header option is set to true. Field names in the schema and column names in CSV headers are checked by their positions taking into account spark.sql.caseSensitive. If None is set, true is used by default. Though the default value is true, it is recommended to disable the enforceSchema option to avoid incorrect results.

ignoreLeadingWhiteSpace [str or bool, optional] A flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, false.

ignoreTrailingWhiteSpace [str or bool, optional] A flag indicating whether or not trailing whitespaces from values being read should be skipped. If None is set, it uses the default value, false.

nullValue [str, optional] sets the string representation of a null value. If None is set, it uses the default value, empty string. Since 2.0.1, this nullValue param applies to all supported types including the string type.

nanValue [str, optional] sets the string representation of a non-number value. If None is set, it uses the default value, NaN.

positiveInf [str, optional] sets the string representation of a positive infinity value. If None is set, it uses the default value, Inf.

negativeInf [str, optional] sets the string representation of a negative infinity value. If None is set, it uses the default value, Inf.

dateFormat [str, optional] sets the string that indicates a date format. Custom date formats follow the formats at datetime pattern. This applies to date type. If None is set, it uses the default value, yyyy-MM-dd.

timestampFormat [str, optional] sets the string that indicates a timestamp format. Custom date formats follow the formats at datetime pattern. This applies to timestamp type. If None is set, it uses the default value, yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX].

maxColumns [str or int, optional] defines a hard limit of how many columns a record can have. If None is set, it uses the default value, 20480.

maxCharsPerColumn [str or int, optional] defines the maximum number of characters allowed for any given value being read. If None is set, it uses the default value, -1 meaning unlimited length.

maxMalformedLogPerPartition [str or int, optional] this parameter is no longer used since Spark 2.2.0. If specified, it is ignored.

3.1. Spark SQL 51

Page 56: pyspark Documentation

pyspark Documentation, Release master

mode [str, optional] allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, PERMISSIVE. Note that Spark tries to parse only required columns in CSV under column pruning. Therefore, corrupt records can be different based on required set of fields. This behavior can be controlled by spark.sql.csv.parser.columnPruning.enabled (enabled by default).

• PERMISSIVE: when it meets a corrupted record, puts the malformed string into a field configured by columnNameOfCorruptRecord, and sets malformed fields to null. To keep corrupt records, an user can set a string type field named columnNameOfCorruptRecord in an user-defined schema. If a schema does not have the field, it drops corrupt records during parsing. A record with less/more tokens than schema is not a corrupted record to CSV. When it meets a record having fewer tokens than the length of the schema, sets null to extra fields. When the record has more tokens than the length of the schema, it drops extra tokens.

• DROPMALFORMED: ignores the whole corrupted records.

• FAILFAST: throws an exception when it meets corrupted records.

columnNameOfCorruptRecord [str, optional] allows renaming the new field having malformed string created by PERMISSIVE mode. This overrides spark.sql.columnNameOfCorruptRecord. If None is set, it uses the value specified in spark.sql.columnNameOfCorruptRecord.

multiLine [str or bool, optional] parse records, which may span multiple lines. If None is set, it uses the default value, false.

charToEscapeQuoteEscaping [str, optional] sets a single character used for escaping the escape for the quote character. If None is set, the default value is escape character when escape and quote characters are different, \0 otherwise.

samplingRatio [str or float, optional] defines fraction of rows used for schema inferring. If None is set, it uses the default value, 1.0.

emptyValue [str, optional] sets the string representation of an empty value. If None is set, it uses the default value, empty string.

locale [str, optional] sets a locale as language tag in IETF BCP 47 format. If None is set, it uses the default value, en-US. For instance, locale is used while parsing dates and timestamps.

lineSep [str, optional] defines the line separator that should be used for parsing. If None is set, it covers all \\r, \\r\\n and \\n. Maximum length is 1 character.

pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.

recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.

modifiedBefore (batch only) [an optional timestamp to only include files with] modification times occurring before the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)

modifiedAfter (batch only) [an optional timestamp to only include files with] modification times occurring after the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)


unescapedQuoteHandling [str, optional] defines how the CsvParser will handle values with unescaped quotes. If None is set, it uses the default value, STOP_AT_DELIMITER.

• STOP_AT_CLOSING_QUOTE: If unescaped quotes are found in the input, accumulate the quote character and proceed parsing the value as a quoted value, until a closing quote is found.

• BACK_TO_DELIMITER: If unescaped quotes are found in the input, consider the value as an unquoted value. This will make the parser accumulate all characters of the current parsed value until the delimiter is found. If no delimiter is found in the value, the parser will continue accumulating characters from the input until a delimiter or line ending is found.

• STOP_AT_DELIMITER: If unescaped quotes are found in the input, consider the value as an unquoted value. This will make the parser accumulate all characters until the delimiter or a line ending is found in the input.

• SKIP_VALUE: If unescaped quotes are found in the input, the content parsed for the given value will be skipped and the value set in nullValue will be produced instead.

• RAISE_ERROR: If unescaped quotes are found in the input, a TextParsingException will be thrown.

Examples

>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
[('_c0', 'string'), ('_c1', 'string')]
>>> rdd = sc.textFile('python/test_support/sql/ages.csv')
>>> df2 = spark.read.csv(rdd)
>>> df2.dtypes
[('_c0', 'string'), ('_c1', 'string')]

pyspark.sql.DataFrameReader.format

DataFrameReader.format(source)
Specifies the input data source format.

New in version 1.4.0.

Parameters

source [str] string, name of the data source, e.g. ‘json’, ‘parquet’.

Examples

>>> df = spark.read.format('json').load('python/test_support/sql/people.json')
>>> df.dtypes
[('age', 'bigint'), ('name', 'string')]


pyspark.sql.DataFrameReader.jdbc

DataFrameReader.jdbc(url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, predicates=None, properties=None)
Construct a DataFrame representing the database table named table accessible via JDBC URL url and connection properties.

Partitions of the table will be retrieved in parallel if either column or predicates is specified. lowerBound, upperBound and numPartitions are needed when column is specified.

If both column and predicates are specified, column will be used.

New in version 1.4.0.

Parameters

url [str] a JDBC URL of the form jdbc:subprotocol:subname

table [str] the name of the table

column [str, optional] the name of a column of numeric, date, or timestamp type that will be used for partitioning; if this parameter is specified, then numPartitions, lowerBound (inclusive), and upperBound (exclusive) will form partition strides for generated WHERE clause expressions used to split the column column evenly

lowerBound [str or int, optional] the minimum value of column used to decide partition stride

upperBound [str or int, optional] the maximum value of column used to decide partition stride

numPartitions [int, optional] the number of partitions

predicates [list, optional] a list of expressions suitable for inclusion in WHERE clauses; each one defines one partition of the DataFrame

properties [dict, optional] a dictionary of JDBC database connection arguments. Normally at least properties "user" and "password" with their corresponding values. For example { 'user' : 'SYSTEM', 'password' : 'mypassword' }

Returns

DataFrame

Notes

Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash your external database systems.
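A minimal sketch of a partitioned JDBC read follows; the JDBC URL, table name, bounds, and credentials are placeholders, and the matching JDBC driver must be available on the Spark classpath.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").appName("jdbc-sketch").getOrCreate()

# Placeholder connection details; substitute a real URL, table, and credentials.
df = spark.read.jdbc(
    url="jdbc:postgresql://localhost:5432/mydb",
    table="people",
    column="id",          # numeric column used to split the read into partitions
    lowerBound=1,
    upperBound=100000,
    numPartitions=8,
    properties={"user": "SYSTEM", "password": "mypassword"},
)
df.printSchema()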

pyspark.sql.DataFrameReader.json

DataFrameReader.json(path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, dropFieldIfAllNull=None, encoding=None, locale=None, pathGlobFilter=None, recursiveFileLookup=None, allowNonNumericNumbers=None, modifiedBefore=None, modifiedAfter=None)

Loads JSON files and returns the results as a DataFrame.


JSON Lines (newline-delimited JSON) is supported by default. For JSON (one record per file), set the multiLine parameter to true.

If the schema parameter is not specified, this function goes through the input once to determine the input schema.

New in version 1.4.0.

Parameters

path [str, list or RDD] string represents path to the JSON dataset, or a list of paths, or RDD of Strings storing JSON objects.

schema [pyspark.sql.types.StructType or str, optional] an optional pyspark.sql.types.StructType for the input schema or a DDL-formatted string (For example col0 INT, col1 DOUBLE).

primitivesAsString [str or bool, optional] infers all primitive values as a string type. If None is set, it uses the default value, false.

prefersDecimal [str or bool, optional] infers all floating-point values as a decimal type. If the values do not fit in decimal, then it infers them as doubles. If None is set, it uses the default value, false.

allowComments [str or bool, optional] ignores Java/C++ style comment in JSON records. If None is set, it uses the default value, false.

allowUnquotedFieldNames [str or bool, optional] allows unquoted JSON field names. If None is set, it uses the default value, false.

allowSingleQuotes [str or bool, optional] allows single quotes in addition to double quotes. If None is set, it uses the default value, true.

allowNumericLeadingZero [str or bool, optional] allows leading zeros in numbers (e.g. 00012). If None is set, it uses the default value, false.

allowBackslashEscapingAnyCharacter [str or bool, optional] allows accepting quoting of all character using backslash quoting mechanism. If None is set, it uses the default value, false.

mode [str, optional] allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, PERMISSIVE.

• PERMISSIVE: when it meets a corrupted record, puts the malformed string into a field configured by columnNameOfCorruptRecord, and sets malformed fields to null. To keep corrupt records, an user can set a string type field named columnNameOfCorruptRecord in an user-defined schema. If a schema does not have the field, it drops corrupt records during parsing. When inferring a schema, it implicitly adds a columnNameOfCorruptRecord field in an output schema.

• DROPMALFORMED: ignores the whole corrupted records.

• FAILFAST: throws an exception when it meets corrupted records.

columnNameOfCorruptRecord [str, optional] allows renaming the new field having malformed string created by PERMISSIVE mode. This overrides spark.sql.columnNameOfCorruptRecord. If None is set, it uses the value specified in spark.sql.columnNameOfCorruptRecord.


dateFormat [str, optional] sets the string that indicates a date format. Custom date formats follow the formats at datetime pattern. This applies to date type. If None is set, it uses the default value, yyyy-MM-dd.

timestampFormat [str, optional] sets the string that indicates a timestamp format. Custom date formats follow the formats at datetime pattern. This applies to timestamp type. If None is set, it uses the default value, yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX].

multiLine [str or bool, optional] parse one record, which may span multiple lines, per file. If None is set, it uses the default value, false.

allowUnquotedControlChars [str or bool, optional] allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not.

encoding [str or bool, optional] allows to forcibly set one of standard basic or extended encoding for the JSON files. For example UTF-16BE, UTF-32LE. If None is set, the encoding of input JSON will be detected automatically when the multiLine option is set to true.

lineSep [str, optional] defines the line separator that should be used for parsing. If None is set, it covers all \r, \r\n and \n.

samplingRatio [str or float, optional] defines fraction of input JSON objects used for schema inferring. If None is set, it uses the default value, 1.0.

dropFieldIfAllNull [str or bool, optional] whether to ignore column of all null values or empty array/struct during schema inference. If None is set, it uses the default value, false.

locale [str, optional] sets a locale as language tag in IETF BCP 47 format. If None is set, it uses the default value, en-US. For instance, locale is used while parsing dates and timestamps.

pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.

recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.

allowNonNumericNumbers [str or bool] allows JSON parser to recognize set of "Not-a-Number" (NaN) tokens as legal floating number values. If None is set, it uses the default value, true.

• +INF: for positive infinity, as well as alias of +Infinity and Infinity.

• -INF: for negative infinity, alias -Infinity.

• NaN: for other not-a-numbers, like result of division by zero.

modifiedBefore [an optional timestamp to only include files with] modification times occurring before the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)

modifiedAfter [an optional timestamp to only include files with] modification times occurring after the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)


Examples

>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
[('age', 'bigint'), ('name', 'string')]
>>> rdd = sc.textFile('python/test_support/sql/people.json')
>>> df2 = spark.read.json(rdd)
>>> df2.dtypes
[('age', 'bigint'), ('name', 'string')]

pyspark.sql.DataFrameReader.load

DataFrameReader.load(path=None, format=None, schema=None, **options)
Loads data from a data source and returns it as a DataFrame.

New in version 1.4.0.

Parameters

path [str or list, optional] optional string or a list of string for file-system backed data sources.

format [str, optional] optional string for format of the data source. Default to ‘parquet’.

schema [pyspark.sql.types.StructType or str, optional] optional pyspark.sql.types.StructType for the input schema or a DDL-formatted string (For example col0 INT, col1 DOUBLE).

**options [dict] all other string options

Examples

>>> df = spark.read.format("parquet").load('python/test_support/sql/parquet_partitioned',
...     opt1=True, opt2=1, opt3='str')
>>> df.dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]

>>> df = spark.read.format('json').load(['python/test_support/sql/people.json',
...     'python/test_support/sql/people1.json'])
>>> df.dtypes
[('age', 'bigint'), ('aka', 'string'), ('name', 'string')]

pyspark.sql.DataFrameReader.option

DataFrameReader.option(key, value)
Adds an input option for the underlying data source.

You can set the following option(s) for reading files:

• timeZone: sets the string that indicates a time zone ID to be used to parse timestamps in the JSON/CSV datasources or partition values. The following formats of timeZone are supported:

– Region-based zone ID: It should have the form 'area/city', such as 'America/Los_Angeles'.

– Zone offset: It should be in the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'.


Other short names like 'CST' are not recommended to use because they can be ambiguous. If it isn't set, the current value of the SQL config spark.sql.session.timeZone is used by default.

• pathGlobFilter: an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.

• modifiedBefore: an optional timestamp to only include files with modification times occurring before the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)

• modifiedAfter: an optional timestamp to only include files with modification times occurring after the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)

New in version 1.5.
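As a minimal sketch of chaining option() calls on a read, consider the following; the path and the option values are illustrative assumptions.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").appName("option-sketch").getOrCreate()

# Hypothetical CSV directory; each option() call adds one key/value for the source.
df = (
    spark.read
    .option("timeZone", "America/Los_Angeles")
    .option("pathGlobFilter", "*.csv")
    .option("header", "true")
    .csv("/tmp/data")
)
df.printSchema()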

pyspark.sql.DataFrameReader.options

DataFrameReader.options(**options)

Adds input options for the underlying data source.

You can set the following option(s) for reading files:

• timeZone: sets the string that indicates a time zone ID to be used to parse timestamps in the JSON/CSV datasources or partition values. The following formats of timeZone are supported:

– Region-based zone ID: It should have the form ‘area/city’, such as ‘America/Los_Angeles’.

– Zone offset: It should be in the format ‘(+|-)HH:mm’, for example ‘-08:00’ or ‘+01:00’. Also ‘UTC’ and ‘Z’ are supported as aliases of ‘+00:00’.

Other short names like ‘CST’ are not recommended because they can be ambiguous. If it isn’t set, the current value of the SQL config spark.sql.session.timeZone is used by default.

• pathGlobFilter: an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.

• modifiedBefore: an optional timestamp to only include files with modification times occurring before the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)

• modifiedAfter: an optional timestamp to only include files with modification times occurring after the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)

New in version 1.4.


pyspark.sql.DataFrameReader.orc

DataFrameReader.orc(path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None, modifiedBefore=None, modifiedAfter=None)

Loads ORC files, returning the result as a DataFrame.

New in version 1.5.0.

Parameters

path [str or list]

mergeSchema [str or bool, optional] sets whether we should merge schemas collected from all ORC part-files. This will override spark.sql.orc.mergeSchema. The default value is specified in spark.sql.orc.mergeSchema.

pathGlobFilter [str or bool] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.

recursiveFileLookup [str or bool] recursively scan a directory for files. Using this option disables partition discovery.

modifiedBefore: an optional timestamp to only include files with modification times occurring before the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)

modifiedAfter: an optional timestamp to only include files with modification times occurring after the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)

Examples

>>> df = spark.read.orc('python/test_support/sql/orc_partitioned')
>>> df.dtypes
[('a', 'bigint'), ('b', 'int'), ('c', 'int')]

pyspark.sql.DataFrameReader.parquet

DataFrameReader.parquet(*paths, **options)

Loads Parquet files, returning the result as a DataFrame.

New in version 1.4.0.

Parameters

paths [str]

Other Parameters

mergeSchema [str or bool, optional] sets whether we should merge schemas collected from all Parquet part-files. This will override spark.sql.parquet.mergeSchema. The default value is specified in spark.sql.parquet.mergeSchema.

pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.


recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.

modifiedBefore (batch only): an optional timestamp to only include files with modification times occurring before the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)

modifiedAfter (batch only): an optional timestamp to only include files with modification times occurring after the specified time. The provided timestamp must be in the following format: YYYY-MM-DDTHH:mm:ss (e.g. 2020-06-01T13:00:00)

Examples

>>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned')
>>> df.dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]

pyspark.sql.DataFrameReader.schema

DataFrameReader.schema(schema)

Specifies the input schema.

Some data sources (e.g. JSON) can infer the input schema automatically from data. By specifying the schema here, the underlying data source can skip the schema inference step, and thus speed up data loading.

New in version 1.4.0.

Parameters

schema [pyspark.sql.types.StructType or str] a pyspark.sql.types.StructType object or a DDL-formatted string (For example col0 INT, col1 DOUBLE).

>>> s = spark.read.schema("col0 INT, col1 DOUBLE")
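A slightly fuller sketch, reusing the people.json path and dtypes from the reader examples above; supplying the schema up front skips schema inference (the StructType built here is an assumption about how that file could be described):

>>> from pyspark.sql.types import StructType, StructField, LongType, StringType
>>> schema = StructType([
...     StructField("age", LongType(), True),     # rendered as 'bigint' in dtypes
...     StructField("name", StringType(), True)])
>>> df = spark.read.schema(schema).json('python/test_support/sql/people.json')
>>> df.dtypes
[('age', 'bigint'), ('name', 'string')]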

pyspark.sql.DataFrameReader.table

DataFrameReader.table(tableName)

Returns the specified table as a DataFrame.

New in version 1.4.0.

Parameters

tableName [str] string, name of the table.


Examples

>>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned')
>>> df.createOrReplaceTempView('tmpTable')
>>> spark.read.table('tmpTable').dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]

pyspark.sql.DataFrameWriter.bucketBy

DataFrameWriter.bucketBy(numBuckets, col, *cols)

Buckets the output by the given columns. If specified, the output is laid out on the file system similar to Hive’s bucketing scheme.

New in version 2.3.0.

Parameters

numBuckets [int] the number of buckets to save

col [str, list or tuple] a name of a column, or a list of names.

cols [str] additional names (optional). If col is a list it should be empty.

Notes

Applicable for file-based data sources in combination with DataFrameWriter.saveAsTable().

Examples

>>> (df.write.format('parquet')
...     .bucketBy(100, 'year', 'month')
...     .mode("overwrite")
...     .saveAsTable('bucketed_table'))

pyspark.sql.DataFrameWriter.csv

DataFrameWriter.csv(path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None, lineSep=None)

Saves the content of the DataFrame in CSV format at the specified path.

New in version 2.0.0.

Parameters

path [str] the path in any Hadoop supported file system

mode [str, optional] specifies the behavior of the save operation when data already exists.

• append: Append contents of this DataFrame to existing data.

• overwrite: Overwrite existing data.

• ignore: Silently ignore this operation if data already exists.


• error or errorifexists (default case): Throw an exception if data already exists.

compression [str, optional] compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate).

sep [str, optional] sets a separator (one or more characters) for each field and value. If None is set, it uses the default value, ,.

quote [str, optional] sets a single character used for escaping quoted values where the separator can be part of the value. If None is set, it uses the default value, ". If an empty string is set, it uses u0000 (null character).

escape [str, optional] sets a single character used for escaping quotes inside an already quoted value. If None is set, it uses the default value, \.

escapeQuotes [str or bool, optional] a flag indicating whether values containing quotes should always be enclosed in quotes. If None is set, it uses the default value true, escaping all values containing a quote character.

quoteAll [str or bool, optional] a flag indicating whether all values should always be enclosed in quotes. If None is set, it uses the default value false, only escaping values containing a quote character.

header [str or bool, optional] writes the names of columns as the first line. If None is set, it uses the default value, false.

nullValue [str, optional] sets the string representation of a null value. If None is set, it uses the default value, empty string.

dateFormat [str, optional] sets the string that indicates a date format. Custom date formats follow the formats at datetime pattern. This applies to date type. If None is set, it uses the default value, yyyy-MM-dd.

timestampFormat [str, optional] sets the string that indicates a timestamp format. Custom date formats follow the formats at datetime pattern. This applies to timestamp type. If None is set, it uses the default value, yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX].

ignoreLeadingWhiteSpace [str or bool, optional] a flag indicating whether or not leading whitespaces from values being written should be skipped. If None is set, it uses the default value, true.

ignoreTrailingWhiteSpace [str or bool, optional] a flag indicating whether or not trailing whitespaces from values being written should be skipped. If None is set, it uses the default value, true.

charToEscapeQuoteEscaping [str, optional] sets a single character used for escaping the escape for the quote character. If None is set, the default value is escape character when escape and quote characters are different, \0 otherwise.

encoding [str, optional] sets the encoding (charset) of saved csv files. If None is set, the default UTF-8 charset will be used.

emptyValue [str, optional] sets the string representation of an empty value. If None is set, it uses the default value, "".

lineSep [str, optional] defines the line separator that should be used for writing. If None is set, it uses the default value, \n. Maximum length is 1 character.


Examples

>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))

pyspark.sql.DataFrameWriter.format

DataFrameWriter.format(source)

Specifies the underlying output data source.

New in version 1.4.0.

Parameters

source [str] string, name of the data source, e.g. ‘json’, ‘parquet’.

Examples

>>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data'))

pyspark.sql.DataFrameWriter.insertInto

DataFrameWriter.insertInto(tableName, overwrite=None)

Inserts the content of the DataFrame to the specified table.

It requires that the schema of the DataFrame is the same as the schema of the table.

Any existing data is optionally overwritten, depending on the overwrite argument.

New in version 1.4.
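A minimal sketch, assuming a target table is created first (the table name people_copy is hypothetical):

>>> df.write.saveAsTable('people_copy')                   # create a table with df's schema
>>> df.write.insertInto('people_copy')                    # append df's rows into it
>>> df.write.insertInto('people_copy', overwrite=True)    # or replace the existing rows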

pyspark.sql.DataFrameWriter.jdbc

DataFrameWriter.jdbc(url, table, mode=None, properties=None)

Saves the content of the DataFrame to an external database table via JDBC.

New in version 1.4.0.

Parameters

url [str] a JDBC URL of the form jdbc:subprotocol:subname

table [str] Name of the table in the external database.

mode [str, optional] specifies the behavior of the save operation when data already exists.

• append: Append contents of this DataFrame to existing data.

• overwrite: Overwrite existing data.

• ignore: Silently ignore this operation if data already exists.

• error or errorifexists (default case): Throw an exception if data already exists.

properties [dict] a dictionary of JDBC database connection arguments. Normally at least properties “user” and “password” with their corresponding values. For example { ‘user’ : ‘SYSTEM’, ‘password’ : ‘mypassword’ }


Notes

Don’t create too many partitions in parallel on a large cluster; otherwise Spark might crash your external database systems.
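A hedged sketch of a JDBC write; the URL and table name below are placeholders (not values from this reference), while the properties dictionary mirrors the example given in the parameter description above:

>>> df.write.jdbc(url="jdbc:postgresql://localhost:5432/testdb",    # placeholder JDBC URL
...               table="people",                                   # placeholder table name
...               mode="append",                                    # add rows to existing data
...               properties={'user': 'SYSTEM', 'password': 'mypassword'})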

pyspark.sql.DataFrameWriter.json

DataFrameWriter.json(path, mode=None, compression=None, dateFormat=None, timestampFormat=None, lineSep=None, encoding=None, ignoreNullFields=None)

Saves the content of the DataFrame in JSON format (JSON Lines text format or newline-delimited JSON) at the specified path.

New in version 1.4.0.

Parameters

path [str] the path in any Hadoop supported file system

mode [str, optional] specifies the behavior of the save operation when data already exists.

• append: Append contents of this DataFrame to existing data.

• overwrite: Overwrite existing data.

• ignore: Silently ignore this operation if data already exists.

• error or errorifexists (default case): Throw an exception if data already exists.

compression [str, optional] compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate).

dateFormat [str, optional] sets the string that indicates a date format. Custom date formats follow the formats at datetime pattern. This applies to date type. If None is set, it uses the default value, yyyy-MM-dd.

timestampFormat [str, optional] sets the string that indicates a timestamp format. Custom date formats follow the formats at datetime pattern. This applies to timestamp type. If None is set, it uses the default value, yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX].

encoding [str, optional] specifies encoding (charset) of saved json files. If None is set, the default UTF-8 charset will be used.

lineSep [str, optional] defines the line separator that should be used for writing. If None is set, it uses the default value, \n.

ignoreNullFields [str or bool, optional] Whether to ignore null fields when generating JSON objects. If None is set, it uses the default value, true.

Examples

>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))


pyspark.sql.DataFrameWriter.mode

DataFrameWriter.mode(saveMode)

Specifies the behavior when data or table already exists.

Options include:

• append: Append contents of this DataFrame to existing data.

• overwrite: Overwrite existing data.

• error or errorifexists: Throw an exception if data already exists.

• ignore: Silently ignore this operation if data already exists.

New in version 1.4.0.

Examples

>>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))

pyspark.sql.DataFrameWriter.option

DataFrameWriter.option(key, value)

Adds an output option for the underlying data source.

You can set the following option(s) for writing files:

• timeZone: sets the string that indicates a time zone ID to be used to format timestamps in the JSON/CSV datasources or partition values. The following formats of timeZone are supported:

– Region-based zone ID: It should have the form ‘area/city’, such as ‘America/Los_Angeles’.

– Zone offset: It should be in the format ‘(+|-)HH:mm’, for example ‘-08:00’ or ‘+01:00’. Also ‘UTC’ and ‘Z’ are supported as aliases of ‘+00:00’.

Other short names like ‘CST’ are not recommended because they can be ambiguous. If it isn’t set, the current value of the SQL config spark.sql.session.timeZone is used by default.

New in version 1.5.
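For example, a short sketch of setting timeZone before a CSV write, following the output-path pattern used by the other examples in this reference (the option only matters if timestamp columns are written):

>>> (df.write
...     .option("timeZone", "America/Los_Angeles")   # format timestamps in this zone
...     .csv(os.path.join(tempfile.mkdtemp(), 'data')))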

pyspark.sql.DataFrameWriter.options

DataFrameWriter.options(**options)

Adds output options for the underlying data source.

You can set the following option(s) for writing files:

• timeZone: sets the string that indicates a time zone ID to be used to format timestamps in the JSON/CSV datasources or partition values. The following formats of timeZone are supported:

– Region-based zone ID: It should have the form ‘area/city’, such as ‘America/Los_Angeles’.

– Zone offset: It should be in the format ‘(+|-)HH:mm’, for example ‘-08:00’ or ‘+01:00’. Also ‘UTC’ and ‘Z’ are supported as aliases of ‘+00:00’.

Other short names like ‘CST’ are not recommended because they can be ambiguous. If it isn’t set, the current value of the SQL config spark.sql.session.timeZone is used by default.


New in version 1.4.

pyspark.sql.DataFrameWriter.orc

DataFrameWriter.orc(path, mode=None, partitionBy=None, compression=None)

Saves the content of the DataFrame in ORC format at the specified path.

New in version 1.5.0.

Parameters

path [str] the path in any Hadoop supported file system

mode [str, optional] specifies the behavior of the save operation when data already exists.

• append: Append contents of this DataFrame to existing data.

• overwrite: Overwrite existing data.

• ignore: Silently ignore this operation if data already exists.

• error or errorifexists (default case): Throw an exception if data already exists.

partitionBy [str or list, optional] names of partitioning columns

compression [str, optional] compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, snappy, zlib, and lzo). This will override orc.compress and spark.sql.orc.compression.codec. If None is set, it uses the value specified in spark.sql.orc.compression.codec.

Examples

>>> orc_df = spark.read.orc('python/test_support/sql/orc_partitioned')
>>> orc_df.write.orc(os.path.join(tempfile.mkdtemp(), 'data'))

pyspark.sql.DataFrameWriter.parquet

DataFrameWriter.parquet(path, mode=None, partitionBy=None, compression=None)

Saves the content of the DataFrame in Parquet format at the specified path.

New in version 1.4.0.

Parameters

path [str] the path in any Hadoop supported file system

mode [str, optional] specifies the behavior of the save operation when data already exists.

• append: Append contents of this DataFrame to existing data.

• overwrite: Overwrite existing data.

• ignore: Silently ignore this operation if data already exists.

• error or errorifexists (default case): Throw an exception if data already exists.

partitionBy [str or list, optional] names of partitioning columns


compression [str, optional] compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, uncompressed, snappy, gzip, lzo, brotli, lz4, and zstd). This will override spark.sql.parquet.compression.codec. If None is set, it uses the value specified in spark.sql.parquet.compression.codec.

Examples

>>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))

pyspark.sql.DataFrameWriter.partitionBy

DataFrameWriter.partitionBy(*cols)

Partitions the output by the given columns on the file system.

If specified, the output is laid out on the file system similar to Hive’s partitioning scheme.

New in version 1.4.0.

Parameters

cols [str or list] name of columns

Examples

>>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data'))

pyspark.sql.DataFrameWriter.save

DataFrameWriter.save(path=None, format=None, mode=None, partitionBy=None, **options)

Saves the contents of the DataFrame to a data source.

The data source is specified by the format and a set of options. If format is not specified, the default data source configured by spark.sql.sources.default will be used.

New in version 1.4.0.

Parameters

path [str, optional] the path in a Hadoop supported file system

format [str, optional] the format used to save

mode [str, optional] specifies the behavior of the save operation when data already exists.

• append: Append contents of this DataFrame to existing data.

• overwrite: Overwrite existing data.

• ignore: Silently ignore this operation if data already exists.

• error or errorifexists (default case): Throw an exception if data already exists.

partitionBy [list, optional] names of partitioning columns

**options [dict] all other string options


Examples

>>> df.write.mode("append").save(os.path.join(tempfile.mkdtemp(), 'data'))

pyspark.sql.DataFrameWriter.saveAsTable

DataFrameWriter.saveAsTable(name, format=None, mode=None, partitionBy=None, **options)

Saves the content of the DataFrame as the specified table.

If the table already exists, the behavior of this function depends on the save mode, specified by the mode function (defaulting to throwing an exception). When mode is Overwrite, the schema of the DataFrame does not need to be the same as that of the existing table.

• append: Append contents of this DataFrame to existing data.

• overwrite: Overwrite existing data.

• error or errorifexists: Throw an exception if data already exists.

• ignore: Silently ignore this operation if data already exists.

New in version 1.4.0.

Parameters

name [str] the table name

format [str, optional] the format used to save

mode [str, optional] one of append, overwrite, error, errorifexists, ignore (default: error)

partitionBy [str or list] names of partitioning columns

**options [dict] all other string options
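A short illustrative sketch (the table name people_table is hypothetical):

>>> df.write.saveAsTable('people_table')                    # fails if the table already exists
>>> df.write.mode('overwrite').saveAsTable('people_table')  # replace the existing table contents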

pyspark.sql.DataFrameWriter.sortBy

DataFrameWriter.sortBy(col, *cols)

Sorts the output in each bucket by the given columns on the file system.

New in version 2.3.0.

Parameters

col [str, tuple or list] a name of a column, or a list of names.

cols [str] additional names (optional). If col is a list it should be empty.

Examples

>>> (df.write.format('parquet')
...     .bucketBy(100, 'year', 'month')
...     .sortBy('day')
...     .mode("overwrite")
...     .saveAsTable('sorted_bucketed_table'))


pyspark.sql.DataFrameWriter.text

DataFrameWriter.text(path, compression=None, lineSep=None)

Saves the content of the DataFrame in a text file at the specified path. The text files will be encoded as UTF-8.

New in version 1.6.0.

Parameters

path [str] the path in any Hadoop supported file system

compression [str, optional] compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate).

lineSep [str, optional] defines the line separator that should be used for writing. If None is set,it uses the default value, \n.

The DataFrame must have only one column that is of string type.

Each row becomes a new line in the output file.
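A minimal sketch, assuming df has a string column name as in the surrounding examples; selecting that single column satisfies the one-string-column requirement, and the output path follows the pattern of the other writer examples here:

>>> df.select(df.name).write.text(os.path.join(tempfile.mkdtemp(), 'data'))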

3.1.4 DataFrame APIs

DataFrame.agg(*exprs): Aggregate on the entire DataFrame without groups (shorthand for df.groupBy().agg()).
DataFrame.alias(alias): Returns a new DataFrame with an alias set.
DataFrame.approxQuantile(col, probabilities, ...): Calculates the approximate quantiles of numerical columns of a DataFrame.
DataFrame.cache(): Persists the DataFrame with the default storage level (MEMORY_AND_DISK).
DataFrame.checkpoint([eager]): Returns a checkpointed version of this Dataset.
DataFrame.coalesce(numPartitions): Returns a new DataFrame that has exactly numPartitions partitions.
DataFrame.colRegex(colName): Selects column based on the column name specified as a regex and returns it as Column.
DataFrame.collect(): Returns all the records as a list of Row.
DataFrame.columns: Returns all column names as a list.
DataFrame.corr(col1, col2[, method]): Calculates the correlation of two columns of a DataFrame as a double value.
DataFrame.count(): Returns the number of rows in this DataFrame.
DataFrame.cov(col1, col2): Calculate the sample covariance for the given columns, specified by their names, as a double value.
DataFrame.createGlobalTempView(name): Creates a global temporary view with this DataFrame.
DataFrame.createOrReplaceGlobalTempView(name): Creates or replaces a global temporary view using the given name.
DataFrame.createOrReplaceTempView(name): Creates or replaces a local temporary view with this DataFrame.
DataFrame.createTempView(name): Creates a local temporary view with this DataFrame.
DataFrame.crossJoin(other): Returns the cartesian product with another DataFrame.
DataFrame.crosstab(col1, col2): Computes a pair-wise frequency table of the given columns.
DataFrame.cube(*cols): Create a multi-dimensional cube for the current DataFrame using the specified columns, so we can run aggregations on them.
DataFrame.describe(*cols): Computes basic statistics for numeric and string columns.
DataFrame.distinct(): Returns a new DataFrame containing the distinct rows in this DataFrame.
DataFrame.drop(*cols): Returns a new DataFrame that drops the specified column.
DataFrame.dropDuplicates([subset]): Return a new DataFrame with duplicate rows removed, optionally only considering certain columns.
DataFrame.drop_duplicates([subset]): drop_duplicates() is an alias for dropDuplicates().
DataFrame.dropna([how, thresh, subset]): Returns a new DataFrame omitting rows with null values.
DataFrame.dtypes: Returns all column names and their data types as a list.
DataFrame.exceptAll(other): Return a new DataFrame containing rows in this DataFrame but not in another DataFrame while preserving duplicates.
DataFrame.explain([extended, mode]): Prints the (logical and physical) plans to the console for debugging purpose.
DataFrame.fillna(value[, subset]): Replace null values, alias for na.fill().
DataFrame.filter(condition): Filters rows using the given condition.
DataFrame.first(): Returns the first row as a Row.
DataFrame.foreach(f): Applies the f function to all Row of this DataFrame.
DataFrame.foreachPartition(f): Applies the f function to each partition of this DataFrame.
DataFrame.freqItems(cols[, support]): Finding frequent items for columns, possibly with false positives.
DataFrame.groupBy(*cols): Groups the DataFrame using the specified columns, so we can run aggregation on them.
DataFrame.head([n]): Returns the first n rows.
DataFrame.hint(name, *parameters): Specifies some hint on the current DataFrame.
DataFrame.inputFiles(): Returns a best-effort snapshot of the files that compose this DataFrame.
DataFrame.intersect(other): Return a new DataFrame containing rows only in both this DataFrame and another DataFrame.
DataFrame.intersectAll(other): Return a new DataFrame containing rows in both this DataFrame and another DataFrame while preserving duplicates.
DataFrame.isLocal(): Returns True if the collect() and take() methods can be run locally (without any Spark executors).
DataFrame.isStreaming: Returns True if this Dataset contains one or more sources that continuously return data as it arrives.
DataFrame.join(other[, on, how]): Joins with another DataFrame, using the given join expression.
DataFrame.limit(num): Limits the result count to the number specified.
DataFrame.localCheckpoint([eager]): Returns a locally checkpointed version of this Dataset.
DataFrame.mapInPandas(func, schema): Maps an iterator of batches in the current DataFrame using a Python native function that takes and outputs a pandas DataFrame, and returns the result as a DataFrame.
DataFrame.na: Returns a DataFrameNaFunctions for handling missing values.
DataFrame.orderBy(*cols, **kwargs): Returns a new DataFrame sorted by the specified column(s).
DataFrame.persist([storageLevel]): Sets the storage level to persist the contents of the DataFrame across operations after the first time it is computed.
DataFrame.printSchema(): Prints out the schema in the tree format.
DataFrame.randomSplit(weights[, seed]): Randomly splits this DataFrame with the provided weights.
DataFrame.rdd: Returns the content as a pyspark.RDD of Row.
DataFrame.registerTempTable(name): Registers this DataFrame as a temporary table using the given name.
DataFrame.repartition(numPartitions, *cols): Returns a new DataFrame partitioned by the given partitioning expressions.
DataFrame.repartitionByRange(numPartitions, ...): Returns a new DataFrame partitioned by the given partitioning expressions.
DataFrame.replace(to_replace[, value, subset]): Returns a new DataFrame replacing a value with another value.
DataFrame.rollup(*cols): Create a multi-dimensional rollup for the current DataFrame using the specified columns, so we can run aggregation on them.
DataFrame.sameSemantics(other): Returns True when the logical query plans inside both DataFrames are equal and therefore return same results.
DataFrame.sample([withReplacement, ...]): Returns a sampled subset of this DataFrame.
DataFrame.sampleBy(col, fractions[, seed]): Returns a stratified sample without replacement based on the fraction given on each stratum.
DataFrame.schema: Returns the schema of this DataFrame as a pyspark.sql.types.StructType.
DataFrame.select(*cols): Projects a set of expressions and returns a new DataFrame.
DataFrame.selectExpr(*expr): Projects a set of SQL expressions and returns a new DataFrame.
DataFrame.semanticHash(): Returns a hash code of the logical query plan against this DataFrame.
DataFrame.show([n, truncate, vertical]): Prints the first n rows to the console.
DataFrame.sort(*cols, **kwargs): Returns a new DataFrame sorted by the specified column(s).
DataFrame.sortWithinPartitions(*cols, **kwargs): Returns a new DataFrame with each partition sorted by the specified column(s).
DataFrame.stat: Returns a DataFrameStatFunctions for statistic functions.
DataFrame.storageLevel: Get the DataFrame’s current storage level.
DataFrame.subtract(other): Return a new DataFrame containing rows in this DataFrame but not in another DataFrame.
DataFrame.summary(*statistics): Computes specified statistics for numeric and string columns.
DataFrame.tail(num): Returns the last num rows as a list of Row.
DataFrame.take(num): Returns the first num rows as a list of Row.
DataFrame.toDF(*cols): Returns a new DataFrame with new specified column names.
DataFrame.toJSON([use_unicode]): Converts a DataFrame into an RDD of string.
DataFrame.toLocalIterator([prefetchPartitions]): Returns an iterator that contains all of the rows in this DataFrame.
DataFrame.toPandas(): Returns the contents of this DataFrame as Pandas pandas.DataFrame.
DataFrame.transform(func): Returns a new DataFrame.
DataFrame.union(other): Return a new DataFrame containing union of rows in this and another DataFrame.
DataFrame.unionAll(other): Return a new DataFrame containing union of rows in this and another DataFrame.
DataFrame.unionByName(other[, ...]): Returns a new DataFrame containing union of rows in this and another DataFrame.
DataFrame.unpersist([blocking]): Marks the DataFrame as non-persistent, and remove all blocks for it from memory and disk.
DataFrame.where(condition): where() is an alias for filter().
DataFrame.withColumn(colName, col): Returns a new DataFrame by adding a column or replacing the existing column that has the same name.
DataFrame.withColumnRenamed(existing, new): Returns a new DataFrame by renaming an existing column.
DataFrame.withWatermark(eventTime, ...): Defines an event time watermark for this DataFrame.
DataFrame.write: Interface for saving the content of the non-streaming DataFrame out into external storage.
DataFrame.writeStream: Interface for saving the content of the streaming DataFrame out into external storage.
DataFrame.writeTo(table): Create a write configuration builder for v2 sources.
DataFrameNaFunctions.drop([how, thresh, subset]): Returns a new DataFrame omitting rows with null values.
DataFrameNaFunctions.fill(value[, subset]): Replace null values, alias for na.fill().
DataFrameNaFunctions.replace(to_replace[, ...]): Returns a new DataFrame replacing a value with another value.
DataFrameStatFunctions.approxQuantile(col, ...): Calculates the approximate quantiles of numerical columns of a DataFrame.
DataFrameStatFunctions.corr(col1, col2[, method]): Calculates the correlation of two columns of a DataFrame as a double value.
DataFrameStatFunctions.cov(col1, col2): Calculate the sample covariance for the given columns, specified by their names, as a double value.
DataFrameStatFunctions.crosstab(col1, col2): Computes a pair-wise frequency table of the given columns.
DataFrameStatFunctions.freqItems(cols[, support]): Finding frequent items for columns, possibly with false positives.
DataFrameStatFunctions.sampleBy(col, fractions): Returns a stratified sample without replacement based on the fraction given on each stratum.


pyspark.sql.DataFrame.agg

DataFrame.agg(*exprs)

Aggregate on the entire DataFrame without groups (shorthand for df.groupBy().agg()).

New in version 1.3.0.

Examples

>>> df.agg({"age": "max"}).collect()
[Row(max(age)=5)]
>>> from pyspark.sql import functions as F
>>> df.agg(F.min(df.age)).collect()
[Row(min(age)=2)]

pyspark.sql.DataFrame.alias

DataFrame.alias(alias)

Returns a new DataFrame with an alias set.

New in version 1.3.0.

Parameters

alias [str] an alias name to be set for the DataFrame.

Examples

>>> from pyspark.sql.functions import *
>>> df_as1 = df.alias("df_as1")
>>> df_as2 = df.alias("df_as2")
>>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner')
>>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age").sort(desc("df_as1.name")).collect()
[Row(name='Bob', name='Bob', age=5), Row(name='Alice', name='Alice', age=2)]

pyspark.sql.DataFrame.approxQuantile

DataFrame.approxQuantile(col, probabilities, relativeError)

Calculates the approximate quantiles of numerical columns of a DataFrame.

The result of this algorithm has the following deterministic bound: If the DataFrame has N elements and if we request the quantile at probability p up to error err, then the algorithm will return a sample x from the DataFrame so that the exact rank of x is close to (p * N). More precisely,

floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).

This method implements a variation of the Greenwald-Khanna algorithm (with some speed optimizations). The algorithm was first presented in [[https://doi.org/10.1145/375663.375670 Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna.

Note that null values will be ignored in numerical columns before calculation. For columns only containing null values, an empty list is returned.


New in version 2.0.0.

Parameters

col [str, tuple or list] Can be a single column name, or a list of names for multiple columns.

Changed in version 2.2: Added support for multiple columns.

probabilities [list or tuple] a list of quantile probabilities. Each number must belong to [0, 1]. For example 0 is the minimum, 0.5 is the median, 1 is the maximum.

relativeError [float] The relative target precision to achieve (>= 0). If set to zero, the exact quantiles are computed, which could be very expensive. Note that values greater than 1 are accepted but give the same result as 1.

Returns

list the approximate quantiles at the given probabilities. If the input col is a string, the output is a list of floats. If the input col is a list or tuple of strings, the output is also a list, but each element in it is a list of floats, i.e., the output is a list of lists of floats.
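An illustrative sketch against the two-row df (ages 2 and 5) used throughout these examples; with relativeError set to 0 the exact minimum and maximum are returned:

>>> df.approxQuantile("age", [0.0, 1.0], 0.0)   # probability 0 is the minimum, 1 is the maximum
[2.0, 5.0]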

pyspark.sql.DataFrame.cache

DataFrame.cache()

Persists the DataFrame with the default storage level (MEMORY_AND_DISK).

New in version 1.3.0.

Notes

The default storage level has changed to MEMORY_AND_DISK to match Scala in 2.0.

pyspark.sql.DataFrame.checkpoint

DataFrame.checkpoint(eager=True)

Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the logical plan of this DataFrame, which is especially useful in iterative algorithms where the plan may grow exponentially. It will be saved to files inside the checkpoint directory set with SparkContext.setCheckpointDir().

New in version 2.1.0.

Parameters

eager [bool, optional] Whether to checkpoint this DataFrame immediately

Notes

This API is experimental.
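A hedged sketch; a checkpoint directory must be set first, and the temporary directory used below is illustrative:

>>> sc.setCheckpointDir(tempfile.mkdtemp())   # checkpoint files are written here
>>> df.checkpoint(eager=True).count()         # returns a DataFrame with the same two rows
2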


pyspark.sql.DataFrame.coalesce

DataFrame.coalesce(numPartitions)

Returns a new DataFrame that has exactly numPartitions partitions.

Similar to coalesce defined on an RDD, this operation results in a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of the 100 new partitions will claim 10 of the current partitions. If a larger number of partitions is requested, it will stay at the current number of partitions.

However, if you’re doing a drastic coalesce, e.g. to numPartitions = 1, this may result in your computation taking place on fewer nodes than you like (e.g. one node in the case of numPartitions = 1). To avoid this, you can call repartition(). This will add a shuffle step, but means the current upstream partitions will be executed in parallel (per whatever the current partitioning is).

New in version 1.4.0.

Parameters

numPartitions [int] specify the target number of partitions

Examples

>>> df.coalesce(1).rdd.getNumPartitions()
1

pyspark.sql.DataFrame.colRegex

DataFrame.colRegex(colName)

Selects column based on the column name specified as a regex and returns it as Column.

New in version 2.3.0.

Parameters

colName [str] string, column name specified as a regex.

Examples

>>> df = spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["Col1", "Col2"])
>>> df.select(df.colRegex("`(Col1)?+.+`")).show()
+----+
|Col2|
+----+
|   1|
|   2|
|   3|
+----+


pyspark.sql.DataFrame.collect

DataFrame.collect()

Returns all the records as a list of Row.

New in version 1.3.0.

Examples

>>> df.collect()
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]

pyspark.sql.DataFrame.columns

property DataFrame.columns

Returns all column names as a list.

New in version 1.3.0.

Examples

>>> df.columns
['age', 'name']

pyspark.sql.DataFrame.corr

DataFrame.corr(col1, col2, method=None)

Calculates the correlation of two columns of a DataFrame as a double value. Currently only supports the Pearson Correlation Coefficient. DataFrame.corr() and DataFrameStatFunctions.corr() are aliases of each other.

New in version 1.4.0.

Parameters

col1 [str] The name of the first column

col2 [str] The name of the second column

method [str, optional] The correlation method. Currently only supports “pearson”
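An illustrative sketch with hypothetical, perfectly linear data (not taken from the original reference):

>>> df_xy = spark.createDataFrame([(1, 2.0), (2, 4.0), (3, 6.0)], ["x", "y"])
>>> df_xy.corr("x", "y")   # Pearson correlation; 1.0 for perfectly linear columns
1.0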

pyspark.sql.DataFrame.count

DataFrame.count()

Returns the number of rows in this DataFrame.

New in version 1.3.0.


Examples

>>> df.count()
2

pyspark.sql.DataFrame.cov

DataFrame.cov(col1, col2)

Calculate the sample covariance for the given columns, specified by their names, as a double value. DataFrame.cov() and DataFrameStatFunctions.cov() are aliases.

New in version 1.4.0.

Parameters

col1 [str] The name of the first column

col2 [str] The name of the second column
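An illustrative sketch with hypothetical data (not taken from the original reference):

>>> df_xy = spark.createDataFrame([(1, 10.0), (2, 20.0), (3, 30.0)], ["x", "y"])
>>> df_xy.cov("x", "y")    # sample covariance of the two columns
10.0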

pyspark.sql.DataFrame.createGlobalTempView

DataFrame.createGlobalTempView(name)

Creates a global temporary view with this DataFrame.

The lifetime of this temporary view is tied to this Spark application. Throws TempTableAlreadyExistsException if the view name already exists in the catalog.

New in version 2.1.0.

Examples

>>> df.createGlobalTempView("people")
>>> df2 = spark.sql("select * from global_temp.people")
>>> sorted(df.collect()) == sorted(df2.collect())
True
>>> df.createGlobalTempView("people")
Traceback (most recent call last):
...
AnalysisException: u"Temporary table 'people' already exists;"
>>> spark.catalog.dropGlobalTempView("people")

pyspark.sql.DataFrame.createOrReplaceGlobalTempView

DataFrame.createOrReplaceGlobalTempView(name)

Creates or replaces a global temporary view using the given name.

The lifetime of this temporary view is tied to this Spark application.

New in version 2.2.0.


Examples

>>> df.createOrReplaceGlobalTempView("people")
>>> df2 = df.filter(df.age > 3)
>>> df2.createOrReplaceGlobalTempView("people")
>>> df3 = spark.sql("select * from global_temp.people")
>>> sorted(df3.collect()) == sorted(df2.collect())
True
>>> spark.catalog.dropGlobalTempView("people")

pyspark.sql.DataFrame.createOrReplaceTempView

DataFrame.createOrReplaceTempView(name)

Creates or replaces a local temporary view with this DataFrame.

The lifetime of this temporary table is tied to the SparkSession that was used to create this DataFrame.

New in version 2.0.0.

Examples

>>> df.createOrReplaceTempView("people")
>>> df2 = df.filter(df.age > 3)
>>> df2.createOrReplaceTempView("people")
>>> df3 = spark.sql("select * from people")
>>> sorted(df3.collect()) == sorted(df2.collect())
True
>>> spark.catalog.dropTempView("people")

pyspark.sql.DataFrame.createTempView

DataFrame.createTempView(name)

Creates a local temporary view with this DataFrame.

The lifetime of this temporary table is tied to the SparkSession that was used to create this DataFrame. Throws TempTableAlreadyExistsException if the view name already exists in the catalog.

New in version 2.0.0.

Examples

>>> df.createTempView("people")
>>> df2 = spark.sql("select * from people")
>>> sorted(df.collect()) == sorted(df2.collect())
True
>>> df.createTempView("people")
Traceback (most recent call last):
...
AnalysisException: u"Temporary table 'people' already exists;"
>>> spark.catalog.dropTempView("people")


pyspark.sql.DataFrame.crossJoin

DataFrame.crossJoin(other)

Returns the cartesian product with another DataFrame.

New in version 2.1.0.

Parameters

other [DataFrame] Right side of the cartesian product.

Examples

>>> df.select("age", "name").collect()
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]
>>> df2.select("name", "height").collect()
[Row(name='Tom', height=80), Row(name='Bob', height=85)]
>>> df.crossJoin(df2.select("height")).select("age", "name", "height").collect()
[Row(age=2, name='Alice', height=80), Row(age=2, name='Alice', height=85),
 Row(age=5, name='Bob', height=80), Row(age=5, name='Bob', height=85)]

pyspark.sql.DataFrame.crosstab

DataFrame.crosstab(col1, col2)

Computes a pair-wise frequency table of the given columns. Also known as a contingency table. The number of distinct values for each column should be less than 1e4. At most 1e6 non-zero pair frequencies will be returned. The first column of each row will be the distinct values of col1 and the column names will be the distinct values of col2. The name of the first column will be $col1_$col2. Pairs that have no occurrences will have zero as their counts. DataFrame.crosstab() and DataFrameStatFunctions.crosstab() are aliases.

New in version 1.4.0.

Parameters

col1 [str] The name of the first column. Distinct items will make the first item of each row.

col2 [str] The name of the second column. Distinct items will make the column names of theDataFrame.

pyspark.sql.DataFrame.cube

DataFrame.cube(*cols)

Create a multi-dimensional cube for the current DataFrame using the specified columns, so we can run aggregations on them.

New in version 1.4.0.


Examples

>>> df.cube("name", df.age).count().orderBy("name", "age").show()
+-----+----+-----+
| name| age|count|
+-----+----+-----+
| null|null|    2|
| null|   2|    1|
| null|   5|    1|
|Alice|null|    1|
|Alice|   2|    1|
|  Bob|null|    1|
|  Bob|   5|    1|
+-----+----+-----+

pyspark.sql.DataFrame.describe

DataFrame.describe(*cols)

Computes basic statistics for numeric and string columns.

New in version 1.3.1.

This includes count, mean, stddev, min, and max. If no columns are given, this function computes statistics for all numerical or string columns.

See also:

DataFrame.summary

Notes

This function is meant for exploratory data analysis, as we make no guarantee about the backward compatibility of the schema of the resulting DataFrame.

Use summary for expanded statistics and control over which statistics to compute.

Examples

>>> df.describe(['age']).show()
+-------+------------------+
|summary|               age|
+-------+------------------+
|  count|                 2|
|   mean|               3.5|
| stddev|2.1213203435596424|
|    min|                 2|
|    max|                 5|
+-------+------------------+
>>> df.describe().show()
+-------+------------------+-----+
|summary|               age| name|
+-------+------------------+-----+
|  count|                 2|    2|
|   mean|               3.5| null|
| stddev|2.1213203435596424| null|
|    min|                 2|Alice|
|    max|                 5|  Bob|
+-------+------------------+-----+

pyspark.sql.DataFrame.distinct

DataFrame.distinct()

Returns a new DataFrame containing the distinct rows in this DataFrame.

New in version 1.3.0.

Examples

>>> df.distinct().count()
2

pyspark.sql.DataFrame.drop

DataFrame.drop(*cols)

Returns a new DataFrame that drops the specified column. This is a no-op if schema doesn’t contain the given column name(s).

New in version 1.4.0.

Parameters

cols [str or Column] a name of the column, or the Column to drop

Examples

>>> df.drop('age').collect()
[Row(name='Alice'), Row(name='Bob')]

>>> df.drop(df.age).collect()
[Row(name='Alice'), Row(name='Bob')]

>>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect()
[Row(age=5, height=85, name='Bob')]

>>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect()
[Row(age=5, name='Bob', height=85)]

>>> df.join(df2, 'name', 'inner').drop('age', 'height').collect()
[Row(name='Bob')]


pyspark.sql.DataFrame.dropDuplicates

DataFrame.dropDuplicates(subset=None)

Return a new DataFrame with duplicate rows removed, optionally only considering certain columns.

For a static batch DataFrame, it just drops duplicate rows. For a streaming DataFrame, it will keep all data across triggers as intermediate state to drop duplicate rows. You can use withWatermark() to limit how late the duplicate data can be, and the system will accordingly limit the state. In addition, data older than the watermark will be dropped to avoid any possibility of duplicates.

drop_duplicates() is an alias for dropDuplicates().

New in version 1.4.0.

Examples

>>> from pyspark.sql import Row
>>> df = sc.parallelize([
...     Row(name='Alice', age=5, height=80),
...     Row(name='Alice', age=5, height=80),
...     Row(name='Alice', age=10, height=80)]).toDF()
>>> df.dropDuplicates().show()
+-----+---+------+
| name|age|height|
+-----+---+------+
|Alice|  5|    80|
|Alice| 10|    80|
+-----+---+------+

>>> df.dropDuplicates(['name', 'height']).show()
+-----+---+------+
| name|age|height|
+-----+---+------+
|Alice|  5|    80|
+-----+---+------+

pyspark.sql.DataFrame.drop_duplicates

DataFrame.drop_duplicates(subset=None)

drop_duplicates() is an alias for dropDuplicates().

New in version 1.4.

pyspark.sql.DataFrame.dropna

DataFrame.dropna(how='any', thresh=None, subset=None)

Returns a new DataFrame omitting rows with null values. DataFrame.dropna() and DataFrameNaFunctions.drop() are aliases of each other.

New in version 1.3.1.

Parameters

how [str, optional] ‘any’ or ‘all’. If ‘any’, drop a row if it contains any nulls. If ‘all’, drop a row only if all its values are null.


thresh [int, optional] default None. If specified, drop rows that have less than thresh non-null values. This overwrites the how parameter.

subset [str, tuple or list, optional] optional list of column names to consider.

Examples

>>> df4.na.drop().show()
+---+------+-----+
|age|height| name|
+---+------+-----+
| 10|    80|Alice|
+---+------+-----+

pyspark.sql.DataFrame.dtypes

property DataFrame.dtypes

Returns all column names and their data types as a list.

New in version 1.3.0.

Examples

>>> df.dtypes
[('age', 'int'), ('name', 'string')]

pyspark.sql.DataFrame.exceptAll

DataFrame.exceptAll(other)

Return a new DataFrame containing rows in this DataFrame but not in another DataFrame while preserving duplicates.

This is equivalent to EXCEPT ALL in SQL. As standard in SQL, this function resolves columns by position (not by name).

New in version 2.4.0.

Examples

>>> df1 = spark.createDataFrame(
...     [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"])
>>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"])

>>> df1.exceptAll(df2).show()
+---+---+
| C1| C2|
+---+---+
|  a|  1|
|  a|  1|
|  a|  2|
|  c|  4|
+---+---+

pyspark.sql.DataFrame.explain

DataFrame.explain(extended=None, mode=None)

Prints the (logical and physical) plans to the console for debugging purpose.

New in version 1.3.0.

Parameters

extended [bool, optional] default False. If False, prints only the physical plan. If this is a string and mode is not specified, the string is treated as the mode.

mode [str, optional] specifies the expected output format of plans.

• simple: Print only a physical plan.

• extended: Print both logical and physical plans.

• codegen: Print a physical plan and generated codes if they are available.

• cost: Print a logical plan and statistics if they are available.

• formatted: Split explain output into two sections: a physical plan outline and node details.

Changed in version 3.0.0: Added optional argument mode to specify the expected output format of plans.

Examples

>>> df.explain()
== Physical Plan ==
*(1) Scan ExistingRDD[age#0,name#1]

>>> df.explain(True)
== Parsed Logical Plan ==
...
== Analyzed Logical Plan ==
...
== Optimized Logical Plan ==
...
== Physical Plan ==
...

>>> df.explain(mode="formatted")
== Physical Plan ==
* Scan ExistingRDD (1)
(1) Scan ExistingRDD [codegen id : 1]
Output [2]: [age#0, name#1]
...


>>> df.explain("cost")
== Optimized Logical Plan ==
...Statistics...
...

pyspark.sql.DataFrame.fillna

DataFrame.fillna(value, subset=None)

Replace null values, alias for na.fill(). DataFrame.fillna() and DataFrameNaFunctions.fill() are aliases of each other.

New in version 1.3.1.

Parameters

value [int, float, string, bool or dict] Value to replace null values with. If the value is a dict, then subset is ignored and value must be a mapping from column name (string) to replacement value. The replacement value must be an int, float, boolean, or string.

subset [str, tuple or list, optional] optional list of column names to consider. Columns specifiedin subset that do not have matching data type are ignored. For example, if value is a string,and subset contains a non-string column, then the non-string column is simply ignored.

Examples

>>> df4.na.fill(50).show()
+---+------+-----+
|age|height| name|
+---+------+-----+
| 10|    80|Alice|
|  5|    50|  Bob|
| 50|    50|  Tom|
| 50|    50| null|
+---+------+-----+

>>> df5.na.fill(False).show()
+----+-------+-----+
| age|   name|  spy|
+----+-------+-----+
|  10|  Alice|false|
|   5|    Bob|false|
|null|Mallory| true|
+----+-------+-----+

>>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
+---+------+-------+
|age|height|   name|
+---+------+-------+
| 10|    80|  Alice|
|  5|  null|    Bob|
| 50|  null|    Tom|
| 50|  null|unknown|
+---+------+-------+


pyspark.sql.DataFrame.filter

DataFrame.filter(condition)

Filters rows using the given condition.

where() is an alias for filter().

New in version 1.3.0.

Parameters

condition [Column or str] a Column of types.BooleanType or a string of SQL expression.

Examples

>>> df.filter(df.age > 3).collect()
[Row(age=5, name='Bob')]
>>> df.where(df.age == 2).collect()
[Row(age=2, name='Alice')]

>>> df.filter("age > 3").collect()
[Row(age=5, name='Bob')]
>>> df.where("age = 2").collect()
[Row(age=2, name='Alice')]

pyspark.sql.DataFrame.first

DataFrame.first()

Returns the first row as a Row.

New in version 1.3.0.

Examples

>>> df.first()
Row(age=2, name='Alice')

pyspark.sql.DataFrame.foreach

DataFrame.foreach(f)

Applies the f function to all Row of this DataFrame.

This is a shorthand for df.rdd.foreach().

New in version 1.3.0.


Examples

>>> def f(person):
...     print(person.name)
>>> df.foreach(f)

pyspark.sql.DataFrame.foreachPartition

DataFrame.foreachPartition(f)

Applies the f function to each partition of this DataFrame.

This is a shorthand for df.rdd.foreachPartition().

New in version 1.3.0.

Examples

>>> def f(people):
...     for person in people:
...         print(person.name)
>>> df.foreachPartition(f)

pyspark.sql.DataFrame.freqItems

DataFrame.freqItems(cols, support=None)

Finding frequent items for columns, possibly with false positives. Using the frequent element count algorithm described in “https://doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou”. DataFrame.freqItems() and DataFrameStatFunctions.freqItems() are aliases.

New in version 1.4.0.

Parameters

cols [list or tuple] Names of the columns to calculate frequent items for as a list or tuple of strings.

support [float, optional] The frequency with which to consider an item ‘frequent’. Default is 1%. The support must be greater than 1e-4.

Notes

This function is meant for exploratory data analysis, as we make no guarantee about the backward compatibility of the schema of the resulting DataFrame.
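An illustrative sketch with hypothetical data; the exact contents of the returned arrays may include false positives, so only the shape of the result is shown:

>>> df_items = spark.createDataFrame([(1,), (1,), (1,), (2,)], ["a"])
>>> freq = df_items.freqItems(["a"], support=0.5)   # items seen in at least ~50% of rows
>>> freq.count()   # a single summary row; each column holds an array of frequent items
1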


pyspark.sql.DataFrame.groupBy

DataFrame.groupBy(*cols)

Groups the DataFrame using the specified columns, so we can run aggregation on them. See GroupedData for all the available aggregate functions.

groupby() is an alias for groupBy().

New in version 1.3.0.

Parameters

cols [list, str or Column] columns to group by. Each element should be a column name (string) or an expression (Column).

Examples

>>> df.groupBy().avg().collect()
[Row(avg(age)=3.5)]
>>> sorted(df.groupBy('name').agg({'age': 'mean'}).collect())
[Row(name='Alice', avg(age)=2.0), Row(name='Bob', avg(age)=5.0)]
>>> sorted(df.groupBy(df.name).avg().collect())
[Row(name='Alice', avg(age)=2.0), Row(name='Bob', avg(age)=5.0)]
>>> sorted(df.groupBy(['name', df.age]).count().collect())
[Row(name='Alice', age=2, count=1), Row(name='Bob', age=5, count=1)]

pyspark.sql.DataFrame.head

DataFrame.head(n=None)

Returns the first n rows.

New in version 1.3.0.

Parameters

n [int, optional] default 1. Number of rows to return.

Returns

If n is greater than 1, return a list of Row .

If n is 1, return a single Row.

Notes

This method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver’s memory.


Examples

>>> df.head()
Row(age=2, name='Alice')
>>> df.head(1)
[Row(age=2, name='Alice')]

pyspark.sql.DataFrame.hint

DataFrame.hint(name, *parameters)

Specifies some hint on the current DataFrame.

New in version 2.2.0.

Parameters

name [str] A name of the hint.

parameters [str, list, float or int] Optional parameters.

Returns

DataFrame

Examples

>>> df.join(df2.hint("broadcast"), "name").show()
+----+---+------+
|name|age|height|
+----+---+------+
| Bob|  5|    85|
+----+---+------+

pyspark.sql.DataFrame.inputFiles

DataFrame.inputFiles()

Returns a best-effort snapshot of the files that compose this DataFrame. This method simply asks each constituent BaseRelation for its respective files and takes the union of all results. Depending on the source relations, this may not find all input files. Duplicates are removed.

New in version 3.1.0.

Examples

>>> df = spark.read.load("examples/src/main/resources/people.json", format="json")
>>> len(df.inputFiles())
1


pyspark.sql.DataFrame.intersect

DataFrame.intersect(other)

Return a new DataFrame containing rows only in both this DataFrame and another DataFrame.

This is equivalent to INTERSECT in SQL.

New in version 1.3.
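A minimal sketch mirroring the intersectAll example below; unlike intersectAll, duplicate rows are removed from the result:

>>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"])
>>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"])
>>> df1.intersect(df2).sort("C1", "C2").collect()
[Row(C1='a', C2=1), Row(C1='b', C2=3)]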

pyspark.sql.DataFrame.intersectAll

DataFrame.intersectAll(other)
Return a new DataFrame containing rows in both this DataFrame and another DataFrame while preserving duplicates.

This is equivalent to INTERSECT ALL in SQL. As standard in SQL, this function resolves columns by position (not by name).

New in version 2.4.0.

Examples

>>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"])
>>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"])

>>> df1.intersectAll(df2).sort("C1", "C2").show()
+---+---+
| C1| C2|
+---+---+
|  a|  1|
|  a|  1|
|  b|  3|
+---+---+

pyspark.sql.DataFrame.isLocal

DataFrame.isLocal()
Returns True if the collect() and take() methods can be run locally (without any Spark executors).

New in version 1.3.
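
A minimal sketch (not part of the original reference), assuming an active SparkSession named spark; a distributed range is not a purely local plan:

>>> spark.range(5).isLocal()
False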

pyspark.sql.DataFrame.isStreaming

property DataFrame.isStreaming
Returns True if this Dataset contains one or more sources that continuously return data as it arrives. A Dataset that reads data from a streaming source must be executed as a StreamingQuery using the start() method in DataStreamWriter. Methods that return a single answer (e.g., count() or collect()) will throw an AnalysisException when there is a streaming source present.

New in version 2.0.0.


Notes

This API is evolving.
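
A short illustration (not part of the original reference), using the built-in rate source and assuming an active SparkSession named spark:

>>> sdf = spark.readStream.format("rate").load()  # streaming source
>>> sdf.isStreaming
True
>>> spark.range(1).isStreaming
False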

pyspark.sql.DataFrame.join

DataFrame.join(other, on=None, how=None)
Joins with another DataFrame, using the given join expression.

New in version 1.3.0.

Parameters

other [DataFrame] Right side of the join

on [str, list or Column, optional] a string for the join column name, a list of column names, a join expression (Column), or a list of Columns. If on is a string or a list of strings indicating the name of the join column(s), the column(s) must exist on both sides, and this performs an equi-join.

how [str, optional] default inner. Must be one of: inner, cross, outer, full, fullouter, full_outer, left, leftouter, left_outer, right, rightouter, right_outer, semi, leftsemi, left_semi, anti, leftanti and left_anti.

Examples

The following performs a full outer join between df and df2.

>>> from pyspark.sql.functions import desc
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height) \
...     .sort(desc("name")).collect()
[Row(name='Bob', height=85), Row(name='Alice', height=None), Row(name=None, height=80)]

>>> df.join(df2, 'name', 'outer').select('name', 'height').sort(desc("name")).collect()
[Row(name='Tom', height=80), Row(name='Bob', height=85), Row(name='Alice', height=None)]

>>> cond = [df.name == df3.name, df.age == df3.age]
>>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect()
[Row(name='Alice', age=2), Row(name='Bob', age=5)]

>>> df.join(df2, 'name').select(df.name, df2.height).collect()
[Row(name='Bob', height=85)]

>>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect()
[Row(name='Bob', age=5)]


pyspark.sql.DataFrame.limit

DataFrame.limit(num)
Limits the result count to the number specified.

New in version 1.3.0.

Examples

>>> df.limit(1).collect()
[Row(age=2, name='Alice')]
>>> df.limit(0).collect()
[]

pyspark.sql.DataFrame.localCheckpoint

DataFrame.localCheckpoint(eager=True)
Returns a locally checkpointed version of this Dataset. Checkpointing can be used to truncate the logical plan of this DataFrame, which is especially useful in iterative algorithms where the plan may grow exponentially. Local checkpoints are stored in the executors using the caching subsystem and therefore they are not reliable.

New in version 2.3.0.

Parameters

eager [bool, optional] Whether to checkpoint this DataFrame immediately

Notes

This API is experimental.
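
A minimal sketch (not part of the original reference), assuming an active SparkSession named spark:

>>> df = spark.range(100)
>>> truncated = df.localCheckpoint()  # cuts the lineage; blocks are kept on the executors
>>> truncated.count()
100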

pyspark.sql.DataFrame.mapInPandas

DataFrame.mapInPandas(func, schema)
Maps an iterator of batches in the current DataFrame using a Python native function that takes and outputs a pandas DataFrame, and returns the result as a DataFrame.

The function should take an iterator of pandas.DataFrames and return another iterator of pandas.DataFrames. All columns are passed together as an iterator of pandas.DataFrames to the function, and the returned iterator of pandas.DataFrames is combined as a DataFrame. Each pandas.DataFrame size can be controlled by spark.sql.execution.arrow.maxRecordsPerBatch.

New in version 3.0.0.

Parameters

func [function] a Python native function that takes an iterator of pandas.DataFrames, and outputs an iterator of pandas.DataFrames.

schema [pyspark.sql.types.DataType or str] the return type of the func in PySpark. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.

See also:

pyspark.sql.functions.pandas_udf


Notes

This API is experimental

Examples

>>> from pyspark.sql.functions import pandas_udf
>>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
>>> def filter_func(iterator):
...     for pdf in iterator:
...         yield pdf[pdf.id == 1]
>>> df.mapInPandas(filter_func, df.schema).show()
+---+---+
| id|age|
+---+---+
|  1| 21|
+---+---+

pyspark.sql.DataFrame.na

property DataFrame.na
Returns a DataFrameNaFunctions for handling missing values.

New in version 1.3.1.

pyspark.sql.DataFrame.orderBy

DataFrame.orderBy(*cols, **kwargs)
Returns a new DataFrame sorted by the specified column(s).

New in version 1.3.0.

Parameters

cols [str, list, or Column, optional] list of Column or column names to sort by.

Other Parameters

ascending [bool or list, optional] boolean or list of boolean (default True). Sort ascending vs. descending. Specify list for multiple sort orders. If a list is specified, the length of the list must equal the length of the cols.

Examples

>>> df.sort(df.age.desc()).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> df.sort("age", ascending=False).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> df.orderBy(df.age.desc()).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> from pyspark.sql.functions import *
>>> df.sort(asc("age")).collect()
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]
>>> df.orderBy(desc("age"), "name").collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]

pyspark.sql.DataFrame.persist

DataFrame.persist(storageLevel=StorageLevel(True, True, False, True, 1))
Sets the storage level to persist the contents of the DataFrame across operations after the first time it is computed. This can only be used to assign a new storage level if the DataFrame does not have a storage level set yet. If no storage level is specified, it defaults to MEMORY_AND_DISK_DESER.

New in version 1.3.0.

Notes

The default storage level has changed to MEMORY_AND_DISK_DESER to match Scala in 3.0.
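
A minimal sketch (not part of the original reference), assuming an active SparkSession named spark:

>>> from pyspark import StorageLevel
>>> df = spark.range(10)
>>> df.persist(StorageLevel.DISK_ONLY).storageLevel
StorageLevel(True, False, False, False, 1)
>>> _ = df.unpersist()  # release the persisted blocks again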

pyspark.sql.DataFrame.printSchema

DataFrame.printSchema()
Prints out the schema in the tree format.

New in version 1.3.0.

Examples

>>> df.printSchema()
root
 |-- age: integer (nullable = true)
 |-- name: string (nullable = true)

pyspark.sql.DataFrame.randomSplit

DataFrame.randomSplit(weights, seed=None)
Randomly splits this DataFrame with the provided weights.

New in version 1.4.0.

Parameters

weights [list] list of doubles as weights with which to split the DataFrame. Weights will be normalized if they don’t sum up to 1.0.

seed [int, optional] The seed for sampling.


Examples

>>> splits = df4.randomSplit([1.0, 2.0], 24)
>>> splits[0].count()
2

>>> splits[1].count()
2

pyspark.sql.DataFrame.rdd

property DataFrame.rdd
Returns the content as a pyspark.RDD of Row.

New in version 1.3.

pyspark.sql.DataFrame.registerTempTable

DataFrame.registerTempTable(name)
Registers this DataFrame as a temporary table using the given name.

The lifetime of this temporary table is tied to the SparkSession that was used to create this DataFrame.

New in version 1.3.0.

Deprecated since version 2.0.0: Use DataFrame.createOrReplaceTempView() instead.

Examples

>>> df.registerTempTable("people")
>>> df2 = spark.sql("select * from people")
>>> sorted(df.collect()) == sorted(df2.collect())
True
>>> spark.catalog.dropTempView("people")

pyspark.sql.DataFrame.repartition

DataFrame.repartition(numPartitions, *cols)
Returns a new DataFrame partitioned by the given partitioning expressions. The resulting DataFrame is hash partitioned.

New in version 1.3.0.

Parameters

numPartitions [int] can be an int to specify the target number of partitions or a Column. If it is a Column, it will be used as the first partitioning column. If not specified, the default number of partitions is used.

cols [str or Column] partitioning columns.

Changed in version 1.6: Added optional arguments to specify the partitioning columns. Also made numPartitions optional if partitioning columns are specified.


Examples

>>> df.repartition(10).rdd.getNumPartitions()
10
>>> data = df.union(df).repartition("age")
>>> data.show()
+---+-----+
|age| name|
+---+-----+
|  5|  Bob|
|  5|  Bob|
|  2|Alice|
|  2|Alice|
+---+-----+
>>> data = data.repartition(7, "age")
>>> data.show()
+---+-----+
|age| name|
+---+-----+
|  2|Alice|
|  5|  Bob|
|  2|Alice|
|  5|  Bob|
+---+-----+
>>> data.rdd.getNumPartitions()
7
>>> data = data.repartition("name", "age")
>>> data.show()
+---+-----+
|age| name|
+---+-----+
|  5|  Bob|
|  5|  Bob|
|  2|Alice|
|  2|Alice|
+---+-----+

pyspark.sql.DataFrame.repartitionByRange

DataFrame.repartitionByRange(numPartitions, *cols)
Returns a new DataFrame partitioned by the given partitioning expressions. The resulting DataFrame is range partitioned.

At least one partition-by expression must be specified. When no explicit sort order is specified, “ascending nulls first” is assumed.

New in version 2.4.0.

Parameters

numPartitions [int] can be an int to specify the target number of partitions or a Column. If it is a Column, it will be used as the first partitioning column. If not specified, the default number of partitions is used.

cols [str or Column] partitioning columns.


Notes

Due to performance reasons this method uses sampling to estimate the ranges. Hence, the output may not be consistent, since sampling can return different values. The sample size can be controlled by the config spark.sql.execution.rangeExchange.sampleSizePerPartition.

Examples

>>> df.repartitionByRange(2, "age").rdd.getNumPartitions()
2
>>> df.show()
+---+-----+
|age| name|
+---+-----+
|  2|Alice|
|  5|  Bob|
+---+-----+
>>> df.repartitionByRange(1, "age").rdd.getNumPartitions()
1
>>> data = df.repartitionByRange("age")
>>> df.show()
+---+-----+
|age| name|
+---+-----+
|  2|Alice|
|  5|  Bob|
+---+-----+

pyspark.sql.DataFrame.replace

DataFrame.replace(to_replace, value=<no value>, subset=None)
Returns a new DataFrame replacing a value with another value. DataFrame.replace() and DataFrameNaFunctions.replace() are aliases of each other. Values to_replace and value must have the same type and can only be numerics, booleans, or strings. Value can have None. When replacing, the new value will be cast to the type of the existing column. For numeric replacements all values to be replaced should have unique floating point representation. In case of conflicts (for example with {42: -1, 42.0: 1}) an arbitrary replacement will be used.

New in version 1.4.0.

Parameters

to_replace [bool, int, float, string, list or dict] Value to be replaced. If the value is a dict, then value is ignored or can be omitted, and to_replace must be a mapping between a value and a replacement.

value [bool, int, float, string or None, optional] The replacement value must be a bool, int, float, string or None. If value is a list, value should be of the same length and type as to_replace. If value is a scalar and to_replace is a sequence, then value is used as a replacement for each item in to_replace.

subset [list, optional] optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if value is a string, and subset contains a non-string column, then the non-string column is simply ignored.


Examples

>>> df4.na.replace(10, 20).show()
+----+------+-----+
| age|height| name|
+----+------+-----+
|  20|    80|Alice|
|   5|  null|  Bob|
|null|  null|  Tom|
|null|  null| null|
+----+------+-----+

>>> df4.na.replace('Alice', None).show()
+----+------+----+
| age|height|name|
+----+------+----+
|  10|    80|null|
|   5|  null| Bob|
|null|  null| Tom|
|null|  null|null|
+----+------+----+

>>> df4.na.replace({'Alice': None}).show()
+----+------+----+
| age|height|name|
+----+------+----+
|  10|    80|null|
|   5|  null| Bob|
|null|  null| Tom|
|null|  null|null|
+----+------+----+

>>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
+----+------+----+
| age|height|name|
+----+------+----+
|  10|    80|   A|
|   5|  null|   B|
|null|  null| Tom|
|null|  null|null|
+----+------+----+


pyspark.sql.DataFrame.rollup

DataFrame.rollup(*cols)
Create a multi-dimensional rollup for the current DataFrame using the specified columns, so we can run aggregation on them.

New in version 1.4.0.

Examples

>>> df.rollup("name", df.age).count().orderBy("name", "age").show()
+-----+----+-----+
| name| age|count|
+-----+----+-----+
| null|null|    2|
|Alice|null|    1|
|Alice|   2|    1|
|  Bob|null|    1|
|  Bob|   5|    1|
+-----+----+-----+

pyspark.sql.DataFrame.sameSemantics

DataFrame.sameSemantics(other)
Returns True when the logical query plans inside both DataFrames are equal and therefore return the same results.

New in version 3.1.0.

Notes

The equality comparison here is simplified by tolerating the cosmetic differences such as attribute names.

This API can compare both DataFrames very quickly, but it can still return False for DataFrames that return the same results, for instance, from different plans. Such false negatives can be acceptable, for example, when the result is used for caching.

This API is a developer API.

Examples

>>> df1 = spark.range(10)
>>> df2 = spark.range(10)
>>> df1.withColumn("col1", df1.id * 2).sameSemantics(df2.withColumn("col1", df2.id * 2))
True
>>> df1.withColumn("col1", df1.id * 2).sameSemantics(df2.withColumn("col1", df2.id + 2))
False
>>> df1.withColumn("col1", df1.id * 2).sameSemantics(df2.withColumn("col0", df2.id * 2))
True


pyspark.sql.DataFrame.sample

DataFrame.sample(withReplacement=None, fraction=None, seed=None)
Returns a sampled subset of this DataFrame.

New in version 1.3.0.

Parameters

withReplacement [bool, optional] Sample with replacement or not (default False).

fraction [float, optional] Fraction of rows to generate, range [0.0, 1.0].

seed [int, optional] Seed for sampling (default a random seed).

Notes

This is not guaranteed to provide exactly the fraction specified of the total count of the given DataFrame.

fraction is required, and withReplacement and seed are optional.

Examples

>>> df = spark.range(10)
>>> df.sample(0.5, 3).count()
7
>>> df.sample(fraction=0.5, seed=3).count()
7
>>> df.sample(withReplacement=True, fraction=0.5, seed=3).count()
1
>>> df.sample(1.0).count()
10
>>> df.sample(fraction=1.0).count()
10
>>> df.sample(False, fraction=1.0).count()
10

pyspark.sql.DataFrame.sampleBy

DataFrame.sampleBy(col, fractions, seed=None)
Returns a stratified sample without replacement based on the fraction given on each stratum.

New in version 1.5.0.

Parameters

col [Column or str] column that defines strata

Changed in version 3.0: Added sampling by a column of Column

fractions [dict] sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as zero.

seed [int, optional] random seed

Returns

a new DataFrame that represents the stratified sample


Examples

>>> from pyspark.sql.functions import col
>>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
>>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
>>> sampled.groupBy("key").count().orderBy("key").show()
+---+-----+
|key|count|
+---+-----+
|  0|    3|
|  1|    6|
+---+-----+
>>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
33

pyspark.sql.DataFrame.schema

property DataFrame.schema
Returns the schema of this DataFrame as a pyspark.sql.types.StructType.

New in version 1.3.0.

Examples

>>> df.schema
StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))

pyspark.sql.DataFrame.select

DataFrame.select(*cols)
Projects a set of expressions and returns a new DataFrame.

New in version 1.3.0.

Parameters

cols [str, Column, or list] column names (string) or expressions (Column). If one of the column names is ‘*’, that column is expanded to include all columns in the current DataFrame.

Examples

>>> df.select('*').collect()
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]
>>> df.select('name', 'age').collect()
[Row(name='Alice', age=2), Row(name='Bob', age=5)]
>>> df.select(df.name, (df.age + 10).alias('age')).collect()
[Row(name='Alice', age=12), Row(name='Bob', age=15)]


pyspark.sql.DataFrame.selectExpr

DataFrame.selectExpr(*expr)
Projects a set of SQL expressions and returns a new DataFrame.

This is a variant of select() that accepts SQL expressions.

New in version 1.3.0.

Examples

>>> df.selectExpr("age * 2", "abs(age)").collect()
[Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)]

pyspark.sql.DataFrame.semanticHash

DataFrame.semanticHash()
Returns a hash code of the logical query plan against this DataFrame.

New in version 3.1.0.

Notes

Unlike the standard hash code, the hash is calculated against the query plan simplified by tolerating the cosmetic differences such as attribute names.

This API is a developer API.

Examples

>>> spark.range(10).selectExpr("id as col0").semanticHash()
1855039936
>>> spark.range(10).selectExpr("id as col1").semanticHash()
1855039936

pyspark.sql.DataFrame.show

DataFrame.show(n=20, truncate=True, vertical=False)
Prints the first n rows to the console.

New in version 1.3.0.

Parameters

n [int, optional] Number of rows to show.

truncate [bool, optional] If set to True, truncate strings longer than 20 chars by default. If set to a number greater than one, truncates long strings to length truncate and align cells right.

vertical [bool, optional] If set to True, print output rows vertically (one line per column value).


Examples

>>> df
DataFrame[age: int, name: string]
>>> df.show()
+---+-----+
|age| name|
+---+-----+
|  2|Alice|
|  5|  Bob|
+---+-----+
>>> df.show(truncate=3)
+---+----+
|age|name|
+---+----+
|  2| Ali|
|  5| Bob|
+---+----+
>>> df.show(vertical=True)
-RECORD 0-----
 age  | 2
 name | Alice
-RECORD 1-----
 age  | 5
 name | Bob

pyspark.sql.DataFrame.sort

DataFrame.sort(*cols, **kwargs)
Returns a new DataFrame sorted by the specified column(s).

New in version 1.3.0.

Parameters

cols [str, list, or Column, optional] list of Column or column names to sort by.

Other Parameters

ascending [bool or list, optional] boolean or list of boolean (default True). Sort ascending vs. descending. Specify list for multiple sort orders. If a list is specified, the length of the list must equal the length of the cols.

Examples

>>> df.sort(df.age.desc()).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> df.sort("age", ascending=False).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> df.orderBy(df.age.desc()).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> from pyspark.sql.functions import *
>>> df.sort(asc("age")).collect()
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]
>>> df.orderBy(desc("age"), "name").collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
>>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]

pyspark.sql.DataFrame.sortWithinPartitions

DataFrame.sortWithinPartitions(*cols, **kwargs)
Returns a new DataFrame with each partition sorted by the specified column(s).

New in version 1.6.0.

Parameters

cols [str, list or Column, optional] list of Column or column names to sort by.

Other Parameters

ascending [bool or list, optional] boolean or list of boolean (default True). Sort ascending vs. descending. Specify list for multiple sort orders. If a list is specified, the length of the list must equal the length of the cols.

Examples

>>> df.sortWithinPartitions("age", ascending=False).show()
+---+-----+
|age| name|
+---+-----+
|  2|Alice|
|  5|  Bob|
+---+-----+

pyspark.sql.DataFrame.stat

property DataFrame.stat
Returns a DataFrameStatFunctions for statistic functions.

New in version 1.4.

pyspark.sql.DataFrame.storageLevel

property DataFrame.storageLevel
Get the DataFrame’s current storage level.

New in version 2.1.0.


Examples

>>> df.storageLevel
StorageLevel(False, False, False, False, 1)
>>> df.cache().storageLevel
StorageLevel(True, True, False, True, 1)
>>> df2.persist(StorageLevel.DISK_ONLY_2).storageLevel
StorageLevel(True, False, False, False, 2)

pyspark.sql.DataFrame.subtract

DataFrame.subtract(other)
Return a new DataFrame containing rows in this DataFrame but not in another DataFrame.

This is equivalent to EXCEPT DISTINCT in SQL.

New in version 1.3.
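
A short illustration (not part of the original reference), assuming an active SparkSession named spark; like EXCEPT DISTINCT, the result is deduplicated:

>>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"])
>>> df2 = spark.createDataFrame([("a", 1)], ["C1", "C2"])
>>> df1.subtract(df2).collect()
[Row(C1='b', C2=3)]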

pyspark.sql.DataFrame.summary

DataFrame.summary(*statistics)
Computes specified statistics for numeric and string columns. Available statistics are:

• count
• mean
• stddev
• min
• max
• arbitrary approximate percentiles specified as a percentage (e.g., 75%)

If no statistics are given, this function computes count, mean, stddev, min, approximate quartiles (percentiles at 25%, 50%, and 75%), and max.

New in version 2.3.0.

See also:

DataFrame.display

Notes

This function is meant for exploratory data analysis, as we make no guarantee about the backward compatibilityof the schema of the resulting DataFrame.

Examples

>>> df.summary().show()
+-------+------------------+-----+
|summary|               age| name|
+-------+------------------+-----+
|  count|                 2|    2|
|   mean|               3.5| null|
| stddev|2.1213203435596424| null|
|    min|                 2|Alice|
|    25%|                 2| null|
|    50%|                 2| null|
|    75%|                 5| null|
|    max|                 5|  Bob|
+-------+------------------+-----+


>>> df.summary("count", "min", "25%", "75%", "max").show()
+-------+---+-----+
|summary|age| name|
+-------+---+-----+
|  count|  2|    2|
|    min|  2|Alice|
|    25%|  2| null|
|    75%|  5| null|
|    max|  5|  Bob|
+-------+---+-----+

To do a summary for specific columns first select them:

>>> df.select("age", "name").summary("count").show()
+-------+---+----+
|summary|age|name|
+-------+---+----+
|  count|  2|   2|
+-------+---+----+

pyspark.sql.DataFrame.tail

DataFrame.tail(num)
Returns the last num rows as a list of Row.

Running tail requires moving data into the application’s driver process, and doing so with a very large num can crash the driver process with OutOfMemoryError.

New in version 3.0.0.

Examples

>>> df.tail(1)
[Row(age=5, name='Bob')]

pyspark.sql.DataFrame.take

DataFrame.take(num)
Returns the first num rows as a list of Row.

New in version 1.3.0.

Examples

>>> df.take(2)
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]


pyspark.sql.DataFrame.toDF

DataFrame.toDF(*cols)
Returns a new DataFrame with the new specified column names.

Parameters

cols [str] new column names

Examples

>>> df.toDF('f1', 'f2').collect()
[Row(f1=2, f2='Alice'), Row(f1=5, f2='Bob')]

pyspark.sql.DataFrame.toJSON

DataFrame.toJSON(use_unicode=True)
Converts a DataFrame into an RDD of string.

Each row is turned into a JSON document as one element in the returned RDD.

New in version 1.3.0.

Examples

>>> df.toJSON().first()
'{"age":2,"name":"Alice"}'

pyspark.sql.DataFrame.toLocalIterator

DataFrame.toLocalIterator(prefetchPartitions=False)
Returns an iterator that contains all of the rows in this DataFrame. The iterator will consume as much memory as the largest partition in this DataFrame. With prefetch it may consume up to the memory of the 2 largest partitions.

New in version 2.0.0.

Parameters

prefetchPartitions [bool, optional] If Spark should pre-fetch the next partition before it is needed.

Examples

>>> list(df.toLocalIterator())
[Row(age=2, name='Alice'), Row(age=5, name='Bob')]


pyspark.sql.DataFrame.toPandas

DataFrame.toPandas()
Returns the contents of this DataFrame as Pandas pandas.DataFrame.

This is only available if Pandas is installed and available.

New in version 1.3.0.

Notes

This method should only be used if the resulting pandas DataFrame is expected to be small, as all the data is loaded into the driver’s memory.

Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental.

Examples

>>> df.toPandas()
   age   name
0    2  Alice
1    5    Bob

pyspark.sql.DataFrame.transform

DataFrame.transform(func)
Returns a new DataFrame. Concise syntax for chaining custom transformations.

New in version 3.0.0.

Parameters

func [function] a function that takes and returns a DataFrame.

Examples

>>> from pyspark.sql.functions import col
>>> df = spark.createDataFrame([(1, 1.0), (2, 2.0)], ["int", "float"])
>>> def cast_all_to_int(input_df):
...     return input_df.select([col(col_name).cast("int") for col_name in input_df.columns])
>>> def sort_columns_asc(input_df):
...     return input_df.select(*sorted(input_df.columns))
>>> df.transform(cast_all_to_int).transform(sort_columns_asc).show()
+-----+---+
|float|int|
+-----+---+
|    1|  1|
|    2|  2|
+-----+---+


pyspark.sql.DataFrame.union

DataFrame.union(other)
Return a new DataFrame containing union of rows in this and another DataFrame.

This is equivalent to UNION ALL in SQL. To do a SQL-style set union (that does deduplication of elements), use this function followed by distinct().

Also as standard in SQL, this function resolves columns by position (not by name).

New in version 2.0.
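
A short illustration (not part of the original reference), assuming an active SparkSession named spark:

>>> df1 = spark.createDataFrame([(1, "a")], ["id", "value"])
>>> df2 = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
>>> df1.union(df2).collect()  # duplicates are kept, as in UNION ALL
[Row(id=1, value='a'), Row(id=1, value='a'), Row(id=2, value='b')]
>>> sorted(df1.union(df2).distinct().collect())
[Row(id=1, value='a'), Row(id=2, value='b')]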

pyspark.sql.DataFrame.unionAll

DataFrame.unionAll(other)
Return a new DataFrame containing union of rows in this and another DataFrame.

This is equivalent to UNION ALL in SQL. To do a SQL-style set union (that does deduplication of elements), use this function followed by distinct().

Also as standard in SQL, this function resolves columns by position (not by name).

New in version 1.3.

pyspark.sql.DataFrame.unionByName

DataFrame.unionByName(other, allowMissingColumns=False)
Returns a new DataFrame containing union of rows in this and another DataFrame.

This is different from both UNION ALL and UNION DISTINCT in SQL. To do a SQL-style set union (that does deduplication of elements), use this function followed by distinct().

New in version 2.3.0.

Examples

The difference between this function and union() is that this function resolves columns by name (not by position):

>>> df1 = spark.createDataFrame([[1, 2, 3]], ["col0", "col1", "col2"])
>>> df2 = spark.createDataFrame([[4, 5, 6]], ["col1", "col2", "col0"])
>>> df1.unionByName(df2).show()
+----+----+----+
|col0|col1|col2|
+----+----+----+
|   1|   2|   3|
|   6|   4|   5|
+----+----+----+

When the parameter allowMissingColumns is True, the set of column names in this and other DataFrame can differ; missing columns will be filled with null. Further, the missing columns of this DataFrame will be added at the end in the schema of the union result:

>>> df1 = spark.createDataFrame([[1, 2, 3]], ["col0", "col1", "col2"])
>>> df2 = spark.createDataFrame([[4, 5, 6]], ["col1", "col2", "col3"])
>>> df1.unionByName(df2, allowMissingColumns=True).show()
+----+----+----+----+
|col0|col1|col2|col3|
+----+----+----+----+
|   1|   2|   3|null|
|null|   4|   5|   6|
+----+----+----+----+

Changed in version 3.1.0: Added optional argument allowMissingColumns to specify whether to allow missing columns.

pyspark.sql.DataFrame.unpersist

DataFrame.unpersist(blocking=False)
Marks the DataFrame as non-persistent, and removes all blocks for it from memory and disk.

New in version 1.3.0.

Notes

blocking default has changed to False to match Scala in 2.0.
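
A minimal sketch (not part of the original reference), assuming an active SparkSession named spark; is_cached reflects the persistence state:

>>> df = spark.range(10).cache()
>>> df.is_cached
True
>>> df.unpersist().is_cached
False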

pyspark.sql.DataFrame.where

DataFrame.where(condition)
where() is an alias for filter().

New in version 1.3.
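
A short illustration (not part of the original reference), assuming an active SparkSession named spark; the condition can be a Column or a SQL expression string:

>>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
>>> df.where(df.age > 3).collect()
[Row(age=5, name='Bob')]
>>> df.where("age = 2").collect()
[Row(age=2, name='Alice')]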

pyspark.sql.DataFrame.withColumn

DataFrame.withColumn(colName, col)
Returns a new DataFrame by adding a column or replacing the existing column that has the same name.

The column expression must be an expression over this DataFrame; attempting to add a column from some other DataFrame will raise an error.

New in version 1.3.0.

Parameters

colName [str] string, name of the new column.

col [Column] a Column expression for the new column.


Notes

This method introduces a projection internally. Therefore, calling it multiple times, for instance, via loops in order to add multiple columns can generate big plans which can cause performance issues and even StackOverflowException. To avoid this, use select() with the multiple columns at once.

Examples

>>> df.withColumn('age2', df.age + 2).collect()
[Row(age=2, name='Alice', age2=4), Row(age=5, name='Bob', age2=7)]

pyspark.sql.DataFrame.withColumnRenamed

DataFrame.withColumnRenamed(existing, new)
Returns a new DataFrame by renaming an existing column. This is a no-op if schema doesn’t contain the given column name.

New in version 1.3.0.

Parameters

existing [str] string, name of the existing column to rename.

new [str] string, new name of the column.

Examples

>>> df.withColumnRenamed('age', 'age2').collect()
[Row(age2=2, name='Alice'), Row(age2=5, name='Bob')]

pyspark.sql.DataFrame.withWatermark

DataFrame.withWatermark(eventTime, delayThreshold)
Defines an event time watermark for this DataFrame. A watermark tracks a point in time before which we assume no more late data is going to arrive.

Spark will use this watermark for several purposes:

• To know when a given time window aggregation can be finalized and thus can be emitted when using output modes that do not allow updates.

• To minimize the amount of state that we need to keep for on-going aggregations.

The current watermark is computed by looking at the MAX(eventTime) seen across all of the partitions in the query minus a user specified delayThreshold. Due to the cost of coordinating this value across partitions, the actual watermark used is only guaranteed to be at least delayThreshold behind the actual event time. In some cases we may still process records that arrive more than delayThreshold late.

New in version 2.1.0.

Parameters

eventTime [str or Column] the name of the column that contains the event time of the row.

delayThreshold [str] the minimum delay to wait for data to arrive late, relative to the latest record that has been processed, in the form of an interval (e.g. “1 minute” or “5 hours”).


Notes

This API is evolving.

>>> from pyspark.sql.functions import timestamp_seconds
>>> sdf.select(
...     'name',
...     timestamp_seconds(sdf.time).alias('time')).withWatermark('time', '10 minutes')
DataFrame[name: string, time: timestamp]

pyspark.sql.DataFrame.write

property DataFrame.write
Interface for saving the content of the non-streaming DataFrame out into external storage.

New in version 1.4.0.

Returns

DataFrameWriter
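
A minimal sketch (not part of the original reference), assuming an active SparkSession named spark and a writable temporary directory; the output path is hypothetical:

>>> import os, tempfile
>>> path = os.path.join(tempfile.mkdtemp(), "people_parquet")  # hypothetical output location
>>> df.write.mode("overwrite").parquet(path)
>>> spark.read.parquet(path).count() == df.count()
True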

pyspark.sql.DataFrame.writeStream

property DataFrame.writeStream
Interface for saving the content of the streaming DataFrame out into external storage.

New in version 2.0.0.

Returns

DataStreamWriter

Notes

This API is evolving.
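
A minimal sketch (not part of the original reference), using the built-in rate source and memory sink; the query name rate_sink is arbitrary:

>>> sdf = spark.readStream.format("rate").load()
>>> query = sdf.writeStream.format("memory").queryName("rate_sink").start()
>>> query.isActive
True
>>> query.stop()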

pyspark.sql.DataFrame.writeTo

DataFrame.writeTo(table)
Create a write configuration builder for v2 sources.

This builder is used to configure and execute write operations.

For example, to append or create or replace existing tables.

New in version 3.1.0.


Examples

>>> df.writeTo("catalog.db.table").append()
>>> df.writeTo(
...     "catalog.db.table"
... ).partitionedBy("col").createOrReplace()

pyspark.sql.DataFrameNaFunctions.drop

DataFrameNaFunctions.drop(how='any', thresh=None, subset=None)
Returns a new DataFrame omitting rows with null values. DataFrame.dropna() and DataFrameNaFunctions.drop() are aliases of each other.

New in version 1.3.1.

Parameters

how [str, optional] ‘any’ or ‘all’. If ‘any’, drop a row if it contains any nulls. If ‘all’, drop a row only if all its values are null.

thresh: int, optional, default None. If specified, drop rows that have fewer than thresh non-null values. This overwrites the how parameter.

subset [str, tuple or list, optional] optional list of column names to consider.

Examples

>>> df4.na.drop().show()
+---+------+-----+
|age|height| name|
+---+------+-----+
| 10|    80|Alice|
+---+------+-----+

pyspark.sql.DataFrameNaFunctions.fill

DataFrameNaFunctions.fill(value, subset=None)
Replace null values, alias for na.fill(). DataFrame.fillna() and DataFrameNaFunctions.fill() are aliases of each other.

New in version 1.3.1.

Parameters

value [int, float, string, bool or dict] Value to replace null values with. If the value is a dict, then subset is ignored and value must be a mapping from column name (string) to replacement value. The replacement value must be an int, float, boolean, or string.

subset [str, tuple or list, optional] optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if value is a string, and subset contains a non-string column, then the non-string column is simply ignored.


Examples

>>> df4.na.fill(50).show()
+---+------+-----+
|age|height| name|
+---+------+-----+
| 10|    80|Alice|
|  5|    50|  Bob|
| 50|    50|  Tom|
| 50|    50| null|
+---+------+-----+

>>> df5.na.fill(False).show()
+----+-------+-----+
| age|   name|  spy|
+----+-------+-----+
|  10|  Alice|false|
|   5|    Bob|false|
|null|Mallory| true|
+----+-------+-----+

>>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
+---+------+-------+
|age|height|   name|
+---+------+-------+
| 10|    80|  Alice|
|  5|  null|    Bob|
| 50|  null|    Tom|
| 50|  null|unknown|
+---+------+-------+

pyspark.sql.DataFrameNaFunctions.replace

DataFrameNaFunctions.replace(to_replace, value=<no value>, subset=None)
Returns a new DataFrame replacing a value with another value. DataFrame.replace() and DataFrameNaFunctions.replace() are aliases of each other. Values to_replace and value must have the same type and can only be numerics, booleans, or strings. Value can have None. When replacing, the new value will be cast to the type of the existing column. For numeric replacements all values to be replaced should have unique floating point representation. In case of conflicts (for example with {42: -1, 42.0: 1}) an arbitrary replacement will be used.

New in version 1.4.0.

Parameters

to_replace [bool, int, float, string, list or dict] Value to be replaced. If the value is a dict, then value is ignored or can be omitted, and to_replace must be a mapping between a value and a replacement.

value [bool, int, float, string or None, optional] The replacement value must be a bool, int, float, string or None. If value is a list, value should be of the same length and type as to_replace. If value is a scalar and to_replace is a sequence, then value is used as a replacement for each item in to_replace.

subset [list, optional] optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if value is a string, and subset contains a non-string column, then the non-string column is simply ignored.


Examples

>>> df4.na.replace(10, 20).show()
+----+------+-----+
| age|height| name|
+----+------+-----+
|  20|    80|Alice|
|   5|  null|  Bob|
|null|  null|  Tom|
|null|  null| null|
+----+------+-----+

>>> df4.na.replace('Alice', None).show()
+----+------+----+
| age|height|name|
+----+------+----+
|  10|    80|null|
|   5|  null| Bob|
|null|  null| Tom|
|null|  null|null|
+----+------+----+

>>> df4.na.replace({'Alice': None}).show()
+----+------+----+
| age|height|name|
+----+------+----+
|  10|    80|null|
|   5|  null| Bob|
|null|  null| Tom|
|null|  null|null|
+----+------+----+

>>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
+----+------+----+
| age|height|name|
+----+------+----+
|  10|    80|   A|
|   5|  null|   B|
|null|  null| Tom|
|null|  null|null|
+----+------+----+

pyspark.sql.DataFrameStatFunctions.approxQuantile

DataFrameStatFunctions.approxQuantile(col, probabilities, relativeError)
Calculates the approximate quantiles of numerical columns of a DataFrame.

The result of this algorithm has the following deterministic bound: If the DataFrame has N elements and if we request the quantile at probability p up to error err, then the algorithm will return a sample x from the DataFrame so that the exact rank of x is close to (p * N). More precisely,

floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).

This method implements a variation of the Greenwald-Khanna algorithm (with some speed optimizations). The algorithm was first presented in Space-efficient Online Computation of Quantile Summaries (https://doi.org/10.1145/375663.375670) by Greenwald and Khanna.


Note that null values will be ignored in numerical columns before calculation. For columns only containing null values, an empty list is returned.

New in version 2.0.0.

Parameters

col: str, tuple or list Can be a single column name, or a list of names for multiple columns.

Changed in version 2.2: Added support for multiple columns.

probabilities [list or tuple] a list of quantile probabilities. Each number must belong to [0, 1]. For example 0 is the minimum, 0.5 is the median, 1 is the maximum.

relativeError [float] The relative target precision to achieve (>= 0). If set to zero, the exact quantiles are computed, which could be very expensive. Note that values greater than 1 are accepted but give the same result as 1.

Returns

list the approximate quantiles at the given probabilities. If the input col is a string, the output is a list of floats. If the input col is a list or tuple of strings, the output is also a list, but each element in it is a list of floats, i.e., the output is a list of list of floats.
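
A minimal sketch (not part of the original reference), assuming an active SparkSession named spark; exact values for a non-zero relativeError can vary:

>>> df = spark.range(0, 100)
>>> df.stat.approxQuantile("id", [0.0, 1.0], 0.0)  # exact min and max
[0.0, 99.0]
>>> quartiles = df.stat.approxQuantile("id", [0.25, 0.5, 0.75], 0.1)
>>> len(quartiles)
3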

pyspark.sql.DataFrameStatFunctions.corr

DataFrameStatFunctions.corr(col1, col2, method=None)
Calculates the correlation of two columns of a DataFrame as a double value. Currently only supports the Pearson Correlation Coefficient. DataFrame.corr() and DataFrameStatFunctions.corr() are aliases of each other.

New in version 1.4.0.

Parameters

col1 [str] The name of the first column

col2 [str] The name of the second column

method [str, optional] The correlation method. Currently only supports “pearson”
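
A short illustration (not part of the original reference), assuming an active SparkSession named spark; the two columns below are perfectly linearly related:

>>> df = spark.createDataFrame([(1, 12), (2, 24), (3, 36)], ["a", "b"])
>>> df.stat.corr("a", "b")
1.0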

pyspark.sql.DataFrameStatFunctions.cov

DataFrameStatFunctions.cov(col1, col2)
Calculate the sample covariance for the given columns, specified by their names, as a double value. DataFrame.cov() and DataFrameStatFunctions.cov() are aliases.

New in version 1.4.0.

Parameters

col1 [str] The name of the first column

col2 [str] The name of the second column
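
A short illustration (not part of the original reference), assuming an active SparkSession named spark:

>>> df = spark.createDataFrame([(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)], ["a", "b"])
>>> df.stat.cov("a", "b")  # sample covariance, normalized by n - 1
2.0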


pyspark.sql.DataFrameStatFunctions.crosstab

DataFrameStatFunctions.crosstab(col1, col2)
Computes a pair-wise frequency table of the given columns. Also known as a contingency table. The number of distinct values for each column should be less than 1e4. At most 1e6 non-zero pair frequencies will be returned. The first column of each row will be the distinct values of col1 and the column names will be the distinct values of col2. The name of the first column will be $col1_$col2. Pairs that have no occurrences will have zero as their counts. DataFrame.crosstab() and DataFrameStatFunctions.crosstab() are aliases.

New in version 1.4.0.

Parameters

col1 [str] The name of the first column. Distinct items will make the first item of each row.

col2 [str] The name of the second column. Distinct items will make the column names of the DataFrame.
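
A minimal sketch (not part of the original reference), assuming an active SparkSession named spark; the first output column is named col1_col2:

>>> df = spark.createDataFrame([(1, "a"), (1, "b"), (2, "a")], ["k", "v"])
>>> ct = df.stat.crosstab("k", "v")
>>> sorted(ct.columns)
['a', 'b', 'k_v']
>>> ct.count()  # one row per distinct value of the first column
2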

pyspark.sql.DataFrameStatFunctions.freqItems

DataFrameStatFunctions.freqItems(cols, support=None)
Finding frequent items for columns, possibly with false positives, using the frequent element count algorithm described in https://doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou. DataFrame.freqItems() and DataFrameStatFunctions.freqItems() are aliases.

New in version 1.4.0.

Parameters

cols [list or tuple] Names of the columns to calculate frequent items for, as a list or tuple of strings.

support [float, optional] The frequency with which to consider an item ‘frequent’. Default is 1%. The support must be greater than 1e-4.

Notes

This function is meant for exploratory data analysis, as we make no guarantee about the backward compatibility of the schema of the resulting DataFrame.

pyspark.sql.DataFrameStatFunctions.sampleBy

DataFrameStatFunctions.sampleBy(col, fractions, seed=None)
Returns a stratified sample without replacement based on the fraction given on each stratum.

New in version 1.5.0.

Parameters

col [Column or str] column that defines strata

Changed in version 3.0: Added sampling by a column of Column

fractions [dict] sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as zero.

seed [int, optional] random seed

Returns


a new DataFrame that represents the stratified sample

Examples

>>> from pyspark.sql.functions import col
>>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
>>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
>>> sampled.groupBy("key").count().orderBy("key").show()
+---+-----+
|key|count|
+---+-----+
|  0|    3|
|  1|    6|
+---+-----+
>>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
33

3.1.5 Column APIs

Column.alias(*alias, **kwargs)  Returns this column aliased with a new name or names (in the case of expressions that return more than one column, such as explode).
Column.asc()  Returns a sort expression based on ascending order of the column.
Column.asc_nulls_first()  Returns a sort expression based on ascending order of the column, and null values return before non-null values.
Column.asc_nulls_last()  Returns a sort expression based on ascending order of the column, and null values appear after non-null values.
Column.astype(dataType)  astype() is an alias for cast().
Column.between(lowerBound, upperBound)  A boolean expression that is evaluated to true if the value of this expression is between the given columns.
Column.bitwiseAND(other)  Compute bitwise AND of this expression with another expression.
Column.bitwiseOR(other)  Compute bitwise OR of this expression with another expression.
Column.bitwiseXOR(other)  Compute bitwise XOR of this expression with another expression.
Column.cast(dataType)  Convert the column into type dataType.
Column.contains(other)  Contains the other element.
Column.desc()  Returns a sort expression based on the descending order of the column.
Column.desc_nulls_first()  Returns a sort expression based on the descending order of the column, and null values appear before non-null values.
Column.desc_nulls_last()  Returns a sort expression based on the descending order of the column, and null values appear after non-null values.
Column.dropFields(*fieldNames)  An expression that drops fields in StructType by name.
Column.endswith(other)  String ends with.
Column.eqNullSafe(other)  Equality test that is safe for null values.
Column.getField(name)  An expression that gets a field by name in a StructField.
Column.getItem(key)  An expression that gets an item at position ordinal out of a list, or gets an item by key out of a dict.
Column.isNotNull()  True if the current expression is NOT null.
Column.isNull()  True if the current expression is null.
Column.isin(*cols)  A boolean expression that is evaluated to true if the value of this expression is contained by the evaluated values of the arguments.
Column.like(other)  SQL like expression.
Column.name(*alias, **kwargs)  name() is an alias for alias().
Column.otherwise(value)  Evaluates a list of conditions and returns one of multiple possible result expressions.
Column.over(window)  Define a windowing column.
Column.rlike(other)  SQL RLIKE expression (LIKE with Regex).
Column.startswith(other)  String starts with.
Column.substr(startPos, length)  Return a Column which is a substring of the column.
Column.when(condition, value)  Evaluates a list of conditions and returns one of multiple possible result expressions.
Column.withField(fieldName, col)  An expression that adds/replaces a field in StructType by name.

pyspark.sql.Column.alias

Column.alias(*alias, **kwargs)
Returns this column aliased with a new name or names (in the case of expressions that return more than one column, such as explode).

New in version 1.3.0.

Parameters

alias [str] desired column names (collects all positional arguments passed)

Other Parameters

metadata: dict a dict of information to be stored in metadata attribute of the corresponding StructField (optional, keyword only argument)

Changed in version 2.2.0: Added optional metadata argument.

Examples

>>> df.select(df.age.alias("age2")).collect()
[Row(age2=2), Row(age2=5)]
>>> df.select(df.age.alias("age3", metadata={'max': 99})).schema['age3'].metadata['max']
99


pyspark.sql.Column.asc

Column.asc()
Returns a sort expression based on ascending order of the column.

Examples

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.asc()).collect()
[Row(name='Alice'), Row(name='Tom')]

pyspark.sql.Column.asc_nulls_first

Column.asc_nulls_first()
Returns a sort expression based on ascending order of the column, and null values return before non-null values.

New in version 2.4.0.

Examples

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.asc_nulls_first()).collect()
[Row(name=None), Row(name='Alice'), Row(name='Tom')]

pyspark.sql.Column.asc_nulls_last

Column.asc_nulls_last()
Returns a sort expression based on ascending order of the column, and null values appear after non-null values.

New in version 2.4.0.

Examples

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.asc_nulls_last()).collect()
[Row(name='Alice'), Row(name='Tom'), Row(name=None)]


pyspark.sql.Column.astype

Column.astype(dataType)
astype() is an alias for cast().

New in version 1.4.

pyspark.sql.Column.between

Column.between(lowerBound, upperBound)
A boolean expression that is evaluated to true if the value of this expression is between the given columns.

New in version 1.3.0.

Examples

>>> df.select(df.name, df.age.between(2, 4)).show()
+-----+---------------------------+
| name|((age >= 2) AND (age <= 4))|
+-----+---------------------------+
|Alice|                       true|
|  Bob|                      false|
+-----+---------------------------+

pyspark.sql.Column.bitwiseAND

Column.bitwiseAND(other)
Compute bitwise AND of this expression with another expression.

Parameters

other a value or Column to calculate bitwise and(&) with this Column.

Examples

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(a=170, b=75)])
>>> df.select(df.a.bitwiseAND(df.b)).collect()
[Row((a & b)=10)]

pyspark.sql.Column.bitwiseOR

Column.bitwiseOR(other)
Compute bitwise OR of this expression with another expression.

Parameters

other a value or Column to calculate bitwise or(|) with this Column.


Examples

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(a=170, b=75)])
>>> df.select(df.a.bitwiseOR(df.b)).collect()
[Row((a | b)=235)]

pyspark.sql.Column.bitwiseXOR

Column.bitwiseXOR(other)
Compute bitwise XOR of this expression with another expression.

Parameters

other a value or Column to calculate bitwise xor(^) with this Column.

Examples

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(a=170, b=75)])
>>> df.select(df.a.bitwiseXOR(df.b)).collect()
[Row((a ^ b)=225)]

pyspark.sql.Column.cast

Column.cast(dataType)
Convert the column into type dataType.

New in version 1.3.0.

Examples

>>> df.select(df.age.cast("string").alias('ages')).collect()
[Row(ages='2'), Row(ages='5')]
>>> df.select(df.age.cast(StringType()).alias('ages')).collect()
[Row(ages='2'), Row(ages='5')]

pyspark.sql.Column.contains

Column.contains(other)
Contains the other element. Returns a boolean Column based on a string match.

Parameters

other string in line. A value as a literal or a Column.


Examples

>>> df.filter(df.name.contains('o')).collect()
[Row(age=5, name='Bob')]

pyspark.sql.Column.desc

Column.desc()
Returns a sort expression based on the descending order of the column.

New in version 2.4.0.

Examples

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.desc()).collect()
[Row(name='Tom'), Row(name='Alice')]

pyspark.sql.Column.desc_nulls_first

Column.desc_nulls_first()
Returns a sort expression based on the descending order of the column, and null values appear before non-null values.

New in version 2.4.0.

Examples

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.desc_nulls_first()).collect()
[Row(name=None), Row(name='Tom'), Row(name='Alice')]

pyspark.sql.Column.desc_nulls_last

Column.desc_nulls_last()
Returns a sort expression based on the descending order of the column, and null values appear after non-null values.

New in version 2.4.0.


Examples

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.desc_nulls_last()).collect()
[Row(name='Tom'), Row(name='Alice'), Row(name=None)]

pyspark.sql.Column.dropFields

Column.dropFields(*fieldNames)
An expression that drops fields in StructType by name.

New in version 3.1.0.

Examples

>>> from pyspark.sql import Row
>>> from pyspark.sql.functions import col, lit
>>> df = spark.createDataFrame([
...     Row(a=Row(b=1, c=2, d=3, e=Row(f=4, g=5, h=6)))])
>>> df.withColumn('a', df['a'].dropFields('b')).show()
+-----------------+
|                a|
+-----------------+
|{2, 3, {4, 5, 6}}|
+-----------------+

>>> df.withColumn('a', df['a'].dropFields('b', 'c')).show()
+--------------+
|             a|
+--------------+
|{3, {4, 5, 6}}|
+--------------+

This method supports dropping multiple nested fields directly e.g.

>>> df.withColumn("a", col("a").dropFields("e.g", "e.h")).show()
+--------------+
|             a|
+--------------+
|{1, 2, 3, {4}}|
+--------------+

However, if you are going to add/replace multiple nested fields, it is preferred to extract out the nested struct before adding/replacing multiple fields e.g.

>>> df.select(col("a").withField(
...     "e", col("a.e").dropFields("g", "h")).alias("a")
... ).show()
+--------------+
|             a|
+--------------+
|{1, 2, 3, {4}}|
+--------------+


pyspark.sql.Column.endswith

Column.endswith(other)
String ends with. Returns a boolean Column based on a string match.

Parameters

other [Column or str] string at end of line (do not use a regex $)

Examples

>>> df.filter(df.name.endswith('ice')).collect()
[Row(age=2, name='Alice')]
>>> df.filter(df.name.endswith('ice$')).collect()
[]

pyspark.sql.Column.eqNullSafe

Column.eqNullSafe(other)
Equality test that is safe for null values.

New in version 2.3.0.

Parameters

other a value or Column

Notes

Unlike Pandas, PySpark doesn’t consider NaN values to be NULL. See the NaN Semantics for details.

Examples

>>> from pyspark.sql import Row
>>> df1 = spark.createDataFrame([
...     Row(id=1, value='foo'),
...     Row(id=2, value=None)
... ])
>>> df1.select(
...     df1['value'] == 'foo',
...     df1['value'].eqNullSafe('foo'),
...     df1['value'].eqNullSafe(None)
... ).show()
+-------------+---------------+----------------+
|(value = foo)|(value <=> foo)|(value <=> NULL)|
+-------------+---------------+----------------+
|         true|           true|           false|
|         null|          false|            true|
+-------------+---------------+----------------+
>>> df2 = spark.createDataFrame([
...     Row(value = 'bar'),
...     Row(value = None)
... ])
>>> df1.join(df2, df1["value"] == df2["value"]).count()
0
>>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count()
1
>>> df2 = spark.createDataFrame([
...     Row(id=1, value=float('NaN')),
...     Row(id=2, value=42.0),
...     Row(id=3, value=None)
... ])
>>> df2.select(
...     df2['value'].eqNullSafe(None),
...     df2['value'].eqNullSafe(float('NaN')),
...     df2['value'].eqNullSafe(42.0)
... ).show()
+----------------+---------------+----------------+
|(value <=> NULL)|(value <=> NaN)|(value <=> 42.0)|
+----------------+---------------+----------------+
|           false|           true|           false|
|           false|          false|            true|
|            true|          false|           false|
+----------------+---------------+----------------+

pyspark.sql.Column.getField

Column.getField(name)
An expression that gets a field by name in a StructField.

New in version 1.3.0.

Examples

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(r=Row(a=1, b="b"))])
>>> df.select(df.r.getField("b")).show()
+---+
|r.b|
+---+
|  b|
+---+
>>> df.select(df.r.a).show()
+---+
|r.a|
+---+
|  1|
+---+


pyspark.sql.Column.getItem

Column.getItem(key)
An expression that gets an item at position ordinal out of a list, or gets an item by key out of a dict.

New in version 1.3.0.

Examples

>>> df = spark.createDataFrame([([1, 2], {"key": "value"})], ["l", "d"])
>>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
+----+------+
|l[0]|d[key]|
+----+------+
|   1| value|
+----+------+

pyspark.sql.Column.isNotNull

Column.isNotNull()
True if the current expression is NOT null.

Examples

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(name='Tom', height=80), Row(name='Alice', height=None)])
>>> df.filter(df.height.isNotNull()).collect()
[Row(name='Tom', height=80)]

pyspark.sql.Column.isNull

Column.isNull()
True if the current expression is null.

Examples

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(name='Tom', height=80), Row(name='Alice', height=None)])
>>> df.filter(df.height.isNull()).collect()
[Row(name='Alice', height=None)]


pyspark.sql.Column.isin

Column.isin(*cols)
A boolean expression that is evaluated to true if the value of this expression is contained by the evaluated values of the arguments.

New in version 1.5.0.

Examples

>>> df[df.name.isin("Bob", "Mike")].collect()
[Row(age=5, name='Bob')]
>>> df[df.age.isin([1, 2, 3])].collect()
[Row(age=2, name='Alice')]

pyspark.sql.Column.like

Column.like(other)
SQL like expression. Returns a boolean Column based on a SQL LIKE match.

Parameters

other [str] a SQL LIKE pattern

See also:

pyspark.sql.Column.rlike

Examples

>>> df.filter(df.name.like('Al%')).collect()
[Row(age=2, name='Alice')]

pyspark.sql.Column.name

Column.name(*alias, **kwargs)
name() is an alias for alias().

New in version 2.0.
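Since name() simply forwards to alias(), a minimal, hypothetical doctest-style sketch (it assumes the same df with age/name columns used by the surrounding examples):

>>> df.select(df.age.name("age2")).collect()  # equivalent to df.age.alias("age2")
[Row(age2=2), Row(age2=5)]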

pyspark.sql.Column.otherwise

Column.otherwise(value)
Evaluates a list of conditions and returns one of multiple possible result expressions. If Column.otherwise() is not invoked, None is returned for unmatched conditions.

New in version 1.4.0.

Parameters

value a literal value, or a Column expression.

See also:


pyspark.sql.functions.when

Examples

>>> from pyspark.sql import functions as F
>>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
+-----+-------------------------------------+
| name|CASE WHEN (age > 3) THEN 1 ELSE 0 END|
+-----+-------------------------------------+
|Alice|                                    0|
|  Bob|                                    1|
+-----+-------------------------------------+

pyspark.sql.Column.over

Column.over(window)
Define a windowing column.

New in version 1.4.0.

Parameters

window [WindowSpec]

Returns

Column

Examples

>>> from pyspark.sql import Window
>>> window = Window.partitionBy("name").orderBy("age").rowsBetween(Window.unboundedPreceding, Window.currentRow)
>>> from pyspark.sql.functions import rank, min
>>> from pyspark.sql.functions import desc
>>> df.withColumn("rank", rank().over(window)).withColumn("min", min('age').over(window)).sort(desc("age")).show()
+---+-----+----+---+
|age| name|rank|min|
+---+-----+----+---+
|  5|  Bob|   1|  5|
|  2|Alice|   1|  2|
+---+-----+----+---+

pyspark.sql.Column.rlike

Column.rlike(other)
SQL RLIKE expression (LIKE with Regex). Returns a boolean Column based on a regex match.

Parameters

other [str] an extended regex expression


Examples

>>> df.filter(df.name.rlike('ice$')).collect()
[Row(age=2, name='Alice')]

pyspark.sql.Column.startswith

Column.startswith(other)
String starts with. Returns a boolean Column based on a string match.

Parameters

other [Column or str] string at start of line (do not use a regex ^)

Examples

>>> df.filter(df.name.startswith('Al')).collect()
[Row(age=2, name='Alice')]
>>> df.filter(df.name.startswith('^Al')).collect()
[]

pyspark.sql.Column.substr

Column.substr(startPos, length)
Return a Column which is a substring of the column.

New in version 1.3.0.

Parameters

startPos [Column or int] start position

length [Column or int] length of the substring

Examples

>>> df.select(df.name.substr(1, 3).alias("col")).collect()
[Row(col='Ali'), Row(col='Bob')]

pyspark.sql.Column.when

Column.when(condition, value)
Evaluates a list of conditions and returns one of multiple possible result expressions. If Column.otherwise() is not invoked, None is returned for unmatched conditions.

New in version 1.4.0.

Parameters

condition [Column] a boolean Column expression.

value a literal value, or a Column expression.

See also:


pyspark.sql.functions.when

Examples

>>> from pyspark.sql import functions as F
>>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
+-----+------------------------------------------------------------+
| name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END|
+-----+------------------------------------------------------------+
|Alice|                                                          -1|
|  Bob|                                                           1|
+-----+------------------------------------------------------------+

pyspark.sql.Column.withField

Column.withField(fieldName, col)
An expression that adds/replaces a field in StructType by name.

New in version 3.1.0.

Examples

>>> from pyspark.sql import Row
>>> from pyspark.sql.functions import lit
>>> df = spark.createDataFrame([Row(a=Row(b=1, c=2))])
>>> df.withColumn('a', df['a'].withField('b', lit(3))).select('a.b').show()
+---+
|  b|
+---+
|  3|
+---+
>>> df.withColumn('a', df['a'].withField('d', lit(4))).select('a.d').show()
+---+
|  d|
+---+
|  4|
+---+

3.1.6 Data Types

ArrayType(elementType[, containsNull])             Array data type.
BinaryType()                                       Binary (byte array) data type.
BooleanType()                                      Boolean data type.
ByteType()                                         Byte data type, i.e. a signed integer in a single byte.
DataType()                                         Base class for data types.
DateType()                                         Date (datetime.date) data type.
DecimalType([precision, scale])                    Decimal (decimal.Decimal) data type.
DoubleType()                                       Double data type, representing double precision floats.
FloatType()                                        Float data type, representing single precision floats.
IntegerType()                                      Int data type, i.e. a signed 32-bit integer.
LongType()                                         Long data type, i.e. a signed 64-bit integer.
MapType(keyType, valueType[, valueContainsNull])   Map data type.
NullType()                                         Null type.
ShortType()                                        Short data type, i.e. a signed 16-bit integer.
StringType()                                       String data type.
StructField(name, dataType[, nullable, metadata])  A field in StructType.
StructType([fields])                               Struct type, consisting of a list of StructField.
TimestampType()                                    Timestamp (datetime.datetime) data type.
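Before the per-class reference below, a minimal, hypothetical sketch of composing several of these types into a schema (the column names and sample values are illustrative only, not taken from the examples elsewhere in this reference):

>>> import datetime
>>> from pyspark.sql.types import (StructType, StructField, StringType,
...     IntegerType, ArrayType, MapType, DateType)
>>> schema = StructType([
...     StructField("name", StringType(), False),
...     StructField("tags", ArrayType(StringType()), True),
...     StructField("scores", MapType(StringType(), IntegerType()), True),
...     StructField("joined", DateType(), True)])
>>> df = spark.createDataFrame(
...     [("Alice", ["a", "b"], {"math": 1}, datetime.date(2020, 1, 1))], schema)
>>> df.dtypes
[('name', 'string'), ('tags', 'array<string>'), ('scores', 'map<string,int>'), ('joined', 'date')]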

ArrayType

class pyspark.sql.types.ArrayType(elementType, containsNull=True)
Array data type.

Parameters

elementType [DataType] DataType of each element in the array.

containsNull [bool, optional] whether the array can contain null (None) values.

Examples

>>> ArrayType(StringType()) == ArrayType(StringType(), True)
True
>>> ArrayType(StringType(), False) == ArrayType(StringType())
False

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
fromJson(json)
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

classmethod fromJson(json)

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

BinaryType

class pyspark.sql.types.BinaryType
Binary (byte array) data type.

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

BooleanType

class pyspark.sql.types.BooleanType
Boolean data type.

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

ByteType

class pyspark.sql.types.ByteType
Byte data type, i.e. a signed integer in a single byte.

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

DataType

class pyspark.sql.types.DataType
Base class for data types.

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

DateType

class pyspark.sql.types.DateType
Date (datetime.date) data type.

Methods

fromInternal(v)      Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(d)        Converts a Python object into an internal SQL object.
typeName()

Attributes

EPOCH_ORDINAL

Methods Documentation

fromInternal(v)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(d)
Converts a Python object into an internal SQL object.

classmethod typeName()

Attributes Documentation

EPOCH_ORDINAL = 719163

DecimalType

class pyspark.sql.types.DecimalType(precision=10, scale=0)
Decimal (decimal.Decimal) data type.

The DecimalType must have fixed precision (the maximum total number of digits) and scale (the number of digits on the right of the dot). For example, (5, 2) can support the value from [-999.99 to 999.99].

The precision can be up to 38; the scale must be less than or equal to the precision.

When creating a DecimalType, the default precision and scale is (10, 0). When inferring schema from decimal.Decimal objects, it will be DecimalType(38, 18).

Parameters

precision [int, optional] the maximum (i.e. total) number of digits (default: 10)

scale [int, optional] the number of digits on right side of dot. (default: 0)
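A minimal, hypothetical sketch of how precision and scale appear in a schema (the column name and value are illustrative only):

>>> from decimal import Decimal
>>> from pyspark.sql.types import DecimalType, StructType, StructField
>>> schema = StructType([StructField("amount", DecimalType(5, 2), True)])
>>> df = spark.createDataFrame([(Decimal("999.99"),)], schema)
>>> df.dtypes  # 5 total digits, 2 of them after the decimal point
[('amount', 'decimal(5,2)')]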

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

DoubleType

class pyspark.sql.types.DoubleType
Double data type, representing double precision floats.

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

FloatType

class pyspark.sql.types.FloatType
Float data type, representing single precision floats.

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

IntegerType

class pyspark.sql.types.IntegerType
Int data type, i.e. a signed 32-bit integer.

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

LongType

class pyspark.sql.types.LongType
Long data type, i.e. a signed 64-bit integer.

If the values are beyond the range of [-9223372036854775808, 9223372036854775807], please use DecimalType.

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

MapType

class pyspark.sql.types.MapType(keyType, valueType, valueContainsNull=True)
Map data type.

Parameters

keyType [DataType] DataType of the keys in the map.

valueType [DataType] DataType of the values in the map.

valueContainsNull [bool, optional] indicates whether values can contain null (None) values.

Notes

Keys in a map data type are not allowed to be null (None).

Examples

>>> (MapType(StringType(), IntegerType())
...     == MapType(StringType(), IntegerType(), True))
True
>>> (MapType(StringType(), IntegerType(), False)
...     == MapType(StringType(), FloatType()))
False

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
fromJson(json)
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

classmethod fromJson(json)

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

NullType

class pyspark.sql.types.NullType
Null type.

The data type representing None, used for the types that cannot be inferred.

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

ShortType

class pyspark.sql.types.ShortType
Short data type, i.e. a signed 16-bit integer.

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

StringType

class pyspark.sql.types.StringType
String data type.

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

StructField

class pyspark.sql.types.StructField(name, dataType, nullable=True, metadata=None)
A field in StructType.

Parameters

name [str] name of the field.

dataType [DataType] DataType of the field.

nullable [bool] whether the field can be null (None) or not.

metadata [dict] a dict from string to simple type that can be converted to JSON automatically

Examples

>>> (StructField("f1", StringType(), True)... == StructField("f1", StringType(), True))True>>> (StructField("f1", StringType(), True)... == StructField("f2", StringType(), True))False

Methods

fromInternal(obj)    Converts an internal SQL object into a native Python object.
fromJson(json)
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(obj)
Converts an internal SQL object into a native Python object.

classmethod fromJson(json)

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

typeName()

3.1. Spark SQL 145

Page 150: pyspark Documentation

pyspark Documentation, Release master

StructType

class pyspark.sql.types.StructType(fields=None)
Struct type, consisting of a list of StructField.

This is the data type representing a Row.

Iterating a StructType will iterate over its StructFields. A contained StructField can be accessed by its name or position.

Examples

>>> struct1 = StructType([StructField("f1", StringType(), True)])>>> struct1["f1"]StructField(f1,StringType,true)>>> struct1[0]StructField(f1,StringType,true)

>>> struct1 = StructType([StructField("f1", StringType(), True)])>>> struct2 = StructType([StructField("f1", StringType(), True)])>>> struct1 == struct2True>>> struct1 = StructType([StructField("f1", StringType(), True)])>>> struct2 = StructType([StructField("f1", StringType(), True),... StructField("f2", IntegerType(), False)])>>> struct1 == struct2False

Methods

add(field[, data_type, nullable, metadata])   Construct a StructType by adding new elements to it, to define the schema.
fieldNames()         Returns all field names in a list.
fromInternal(obj)    Converts an internal SQL object into a native Python object.
fromJson(json)
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(obj)      Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

add(field, data_type=None, nullable=True, metadata=None)
Construct a StructType by adding new elements to it, to define the schema. The method accepts either:

a) A single parameter which is a StructField object.

b) Between 2 and 4 parameters as (name, data_type, nullable (optional), metadata (optional)). The data_type parameter may be either a String or a DataType object.

Parameters

field [str or StructField] Either the name of the field or a StructField object

data_type [DataType, optional] If present, the DataType of the StructField to create

nullable [bool, optional] Whether the field to add should be nullable (default True)

metadata [dict, optional] Any additional metadata (default None)

Returns

StructType

Examples

>>> struct1 = StructType().add("f1", StringType(), True).add("f2",→˓StringType(), True, None)>>> struct2 = StructType([StructField("f1", StringType(), True), \... StructField("f2", StringType(), True, None)])>>> struct1 == struct2True>>> struct1 = StructType().add(StructField("f1", StringType(), True))>>> struct2 = StructType([StructField("f1", StringType(), True)])>>> struct1 == struct2True>>> struct1 = StructType().add("f1", "string", True)>>> struct2 = StructType([StructField("f1", StringType(), True)])>>> struct1 == struct2True

fieldNames()
Returns all field names in a list.

Examples

>>> struct = StructType([StructField("f1", StringType(), True)])>>> struct.fieldNames()['f1']

fromInternal(obj)
Converts an internal SQL object into a native Python object.

classmethod fromJson(json)

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(obj)
Converts a Python object into an internal SQL object.

classmethod typeName()

TimestampType

class pyspark.sql.types.TimestampType
Timestamp (datetime.datetime) data type.

Methods

fromInternal(ts)     Converts an internal SQL object into a native Python object.
json()
jsonValue()
needConversion()     Does this type need conversion between Python object and internal SQL object.
simpleString()
toInternal(dt)       Converts a Python object into an internal SQL object.
typeName()

Methods Documentation

fromInternal(ts)
Converts an internal SQL object into a native Python object.

json()

jsonValue()

needConversion()
Does this type need conversion between Python object and internal SQL object.

This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.

simpleString()

toInternal(dt)
Converts a Python object into an internal SQL object.

classmethod typeName()

3.1.7 Row

Row.asDict([recursive]) Return as a dict

pyspark.sql.Row.asDict

Row.asDict(recursive=False)
Return as a dict.

Parameters

recursive [bool, optional] turns the nested Rows to dict (default: False).

Notes

If a row contains duplicate field names, e.g., the rows of a join between two DataFrames that both have fields of the same name, one of the duplicate fields will be selected by asDict. __getitem__ will also return one of the duplicate fields, however the returned value might be different from asDict.

Examples

>>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}True>>> row = Row(key=1, value=Row(name='a', age=2))>>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)}True>>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}True

3.1.8 Functions

abs(col) Computes the absolute value.
acos(col) New in version 1.4.0.
acosh(col) Computes inverse hyperbolic cosine of the input column.
add_months(start, months) Returns the date that is months months after start
aggregate(col, zero, merge[, finish]) Applies a binary operator to an initial state and all elements in the array, and reduces this to a single state.
approxCountDistinct(col[, rsd]) Deprecated since version 2.1.0.
approx_count_distinct(col[, rsd]) Aggregate function: returns a new Column for approximate distinct count of column col.
array(*cols) Creates a new array column.
array_contains(col, value) Collection function: returns null if the array is null, true if the array contains the given value, and false otherwise.
array_distinct(col) Collection function: removes duplicate values from the array.
array_except(col1, col2) Collection function: returns an array of the elements in col1 but not in col2, without duplicates.
array_intersect(col1, col2) Collection function: returns an array of the elements in the intersection of col1 and col2, without duplicates.
array_join(col, delimiter[, null_replacement]) Concatenates the elements of column using the delimiter.
array_max(col) Collection function: returns the maximum value of the array.
array_min(col) Collection function: returns the minimum value of the array.
array_position(col, value) Collection function: Locates the position of the first occurrence of the given value in the given array.
array_remove(col, element) Collection function: Remove all elements that equal to element from the given array.
array_repeat(col, count) Collection function: creates an array containing a column repeated count times.
array_sort(col) Collection function: sorts the input array in ascending order.
array_union(col1, col2) Collection function: returns an array of the elements in the union of col1 and col2, without duplicates.
arrays_overlap(a1, a2) Collection function: returns true if the arrays contain any common non-null element; if not, returns null if both the arrays are non-empty and any of them contains a null element; returns false otherwise.
arrays_zip(*cols) Collection function: Returns a merged array of structs in which the N-th struct contains all N-th values of input arrays.
asc(col) Returns a sort expression based on the ascending order of the given column name.
asc_nulls_first(col) Returns a sort expression based on the ascending order of the given column name, and null values return before non-null values.
asc_nulls_last(col) Returns a sort expression based on the ascending order of the given column name, and null values appear after non-null values.
ascii(col) Computes the numeric value of the first character of the string column.
asin(col) New in version 1.3.0.
asinh(col) Computes inverse hyperbolic sine of the input column.
assert_true(col[, errMsg]) Returns null if the input column is true; throws an exception with the provided error message otherwise.
atan(col) New in version 1.4.0.
atanh(col) Computes inverse hyperbolic tangent of the input column.
atan2(col1, col2) New in version 1.4.0.
avg(col) Aggregate function: returns the average of the values in a group.
base64(col) Computes the BASE64 encoding of a binary column and returns it as a string column.
bin(col) Returns the string representation of the binary value of the given column.
bitwiseNOT(col) Computes bitwise not.
broadcast(df) Marks a DataFrame as small enough for use in broadcast joins.
bround(col[, scale]) Round the given value to scale decimal places using HALF_EVEN rounding mode if scale >= 0 or at integral part when scale < 0.
bucket(numBuckets, col) Partition transform function: A transform for any type that partitions by a hash of the input column.
cbrt(col) Computes the cube-root of the given value.
ceil(col) Computes the ceiling of the given value.
coalesce(*cols) Returns the first column that is not null.
col(col) Returns a Column based on the given column name.
collect_list(col) Aggregate function: returns a list of objects with duplicates.
collect_set(col) Aggregate function: returns a set of objects with duplicate elements eliminated.
column(col) Returns a Column based on the given column name.
concat(*cols) Concatenates multiple input columns together into a single column.
concat_ws(sep, *cols) Concatenates multiple input string columns together into a single string column, using the given separator.
conv(col, fromBase, toBase) Convert a number in a string column from one base to another.
corr(col1, col2) Returns a new Column for the Pearson Correlation Coefficient for col1 and col2.
cos(col) New in version 1.4.0.
cosh(col) New in version 1.4.0.
count(col) Aggregate function: returns the number of items in a group.
countDistinct(col, *cols) Returns a new Column for distinct count of col or cols.
covar_pop(col1, col2) Returns a new Column for the population covariance of col1 and col2.
covar_samp(col1, col2) Returns a new Column for the sample covariance of col1 and col2.
crc32(col) Calculates the cyclic redundancy check value (CRC32) of a binary column and returns the value as a bigint.
create_map(*cols) Creates a new map column.
cume_dist() Window function: returns the cumulative distribution of values within a window partition, i.e.
current_date() Returns the current date at the start of query evaluation as a DateType column.
current_timestamp() Returns the current timestamp at the start of query evaluation as a TimestampType column.
date_add(start, days) Returns the date that is days days after start
date_format(date, format) Converts a date/timestamp/string to a value of string in the format specified by the date format given by the second argument.
date_sub(start, days) Returns the date that is days days before start
date_trunc(format, timestamp) Returns timestamp truncated to the unit specified by the format.
datediff(end, start) Returns the number of days from start to end.
dayofmonth(col) Extract the day of the month of a given date as integer.
dayofweek(col) Extract the day of the week of a given date as integer.
dayofyear(col) Extract the day of the year of a given date as integer.
days(col) Partition transform function: A transform for timestamps and dates to partition data into days.
decode(col, charset) Computes the first argument into a string from a binary using the provided character set (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
degrees(col) Converts an angle measured in radians to an approximately equivalent angle measured in degrees.
dense_rank() Window function: returns the rank of rows within a window partition, without any gaps.
desc(col) Returns a sort expression based on the descending order of the given column name.
desc_nulls_first(col) Returns a sort expression based on the descending order of the given column name, and null values appear before non-null values.
desc_nulls_last(col) Returns a sort expression based on the descending order of the given column name, and null values appear after non-null values.
element_at(col, extraction) Collection function: Returns element of array at given index in extraction if col is array.
encode(col, charset) Computes the first argument into a binary from a string using the provided character set (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
exists(col, f) Returns whether a predicate holds for one or more elements in the array.
exp(col) Computes the exponential of the given value.
explode(col) Returns a new row for each element in the given array or map.
explode_outer(col) Returns a new row for each element in the given array or map.
expm1(col) Computes the exponential of the given value minus one.
expr(str) Parses the expression string into the column that it represents
factorial(col) Computes the factorial of the given value.
filter(col, f) Returns an array of elements for which a predicate holds in a given array.
first(col[, ignorenulls]) Aggregate function: returns the first value in a group.
flatten(col) Collection function: creates a single array from an array of arrays.
floor(col) Computes the floor of the given value.
forall(col, f) Returns whether a predicate holds for every element in the array.
format_number(col, d) Formats the number X to a format like '#,–#,–#.–', rounded to d decimal places with HALF_EVEN round mode, and returns the result as a string.
format_string(format, *cols) Formats the arguments in printf-style and returns the result as a string column.
from_csv(col, schema[, options]) Parses a column containing a CSV string to a row with the specified schema.
from_json(col, schema[, options]) Parses a column containing a JSON string into a MapType with StringType as keys type, StructType or ArrayType with the specified schema.
from_unixtime(timestamp[, format]) Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string representing the timestamp of that moment in the current system time zone in the given format.
from_utc_timestamp(timestamp, tz) This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE.
get_json_object(col, path) Extracts json object from a json string based on json path specified, and returns json string of the extracted json object.
greatest(*cols) Returns the greatest value of the list of column names, skipping null values.
grouping(col) Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated or not, returns 1 for aggregated or 0 for not aggregated in the result set.
grouping_id(*cols) Aggregate function: returns the level of grouping, equals to
hash(*cols) Calculates the hash code of given columns, and returns the result as an int column.
hex(col) Computes hex value of the given column, which could be pyspark.sql.types.StringType, pyspark.sql.types.BinaryType, pyspark.sql.types.IntegerType or pyspark.sql.types.LongType.
hour(col) Extract the hours of a given date as integer.
hours(col) Partition transform function: A transform for timestamps to partition data into hours.
hypot(col1, col2) Computes sqrt(a^2 + b^2) without intermediate overflow or underflow.
initcap(col) Translate the first letter of each word to upper case in the sentence.
input_file_name() Creates a string column for the file name of the current Spark task.
instr(str, substr) Locate the position of the first occurrence of substr column in the given string.
isnan(col) An expression that returns true iff the column is NaN.
isnull(col) An expression that returns true iff the column is null.
json_tuple(col, *fields) Creates a new row for a json column according to the given field names.
kurtosis(col) Aggregate function: returns the kurtosis of the values in a group.
lag(col[, offset, default]) Window function: returns the value that is offset rows before the current row, and defaultValue if there is less than offset rows before the current row.
last(col[, ignorenulls]) Aggregate function: returns the last value in a group.
last_day(date) Returns the last day of the month which the given date belongs to.
lead(col[, offset, default]) Window function: returns the value that is offset rows after the current row, and defaultValue if there is less than offset rows after the current row.
least(*cols) Returns the least value of the list of column names, skipping null values.
length(col) Computes the character length of string data or number of bytes of binary data.
levenshtein(left, right) Computes the Levenshtein distance of the two given strings.
lit(col) Creates a Column of literal value.
locate(substr, str[, pos]) Locate the position of the first occurrence of substr in a string column, after position pos.
log(arg1[, arg2]) Returns the first argument-based logarithm of the second argument.
log10(col) Computes the logarithm of the given value in Base 10.
log1p(col) Computes the natural logarithm of the given value plus one.
log2(col) Returns the base-2 logarithm of the argument.
lower(col) Converts a string expression to lower case.
lpad(col, len, pad) Left-pad the string column to width len with pad.
ltrim(col) Trim the spaces from left end for the specified string value.
map_concat(*cols) Returns the union of all the given maps.
map_entries(col) Collection function: Returns an unordered array of all entries in the given map.
map_filter(col, f) Returns a map whose key-value pairs satisfy a predicate.
map_from_arrays(col1, col2) Creates a new map from two arrays.
map_from_entries(col) Collection function: Returns a map created from the given array of entries.
map_keys(col) Collection function: Returns an unordered array containing the keys of the map.
map_values(col) Collection function: Returns an unordered array containing the values of the map.
map_zip_with(col1, col2, f) Merge two given maps, key-wise into a single map using a function.
max(col) Aggregate function: returns the maximum value of the expression in a group.
md5(col) Calculates the MD5 digest and returns the value as a 32 character hex string.
mean(col) Aggregate function: returns the average of the values in a group.
min(col) Aggregate function: returns the minimum value of the expression in a group.
minute(col) Extract the minutes of a given date as integer.
monotonically_increasing_id() A column that generates monotonically increasing 64-bit integers.
month(col) Extract the month of a given date as integer.
months(col) Partition transform function: A transform for timestamps and dates to partition data into months.
months_between(date1, date2[, roundOff]) Returns number of months between dates date1 and date2.
nanvl(col1, col2) Returns col1 if it is not NaN, or col2 if col1 is NaN.
next_day(date, dayOfWeek) Returns the first date which is later than the value of the date column.
nth_value(col, offset[, ignoreNulls]) Window function: returns the value that is the offsetth row of the window frame (counting from 1), and null if the size of window frame is less than offset rows.
ntile(n) Window function: returns the ntile group id (from 1 to n inclusive) in an ordered window partition.
overlay(src, replace, pos[, len]) Overlay the specified portion of src with replace, starting from byte position pos of src and proceeding for len bytes.
pandas_udf([f, returnType, functionType]) Creates a pandas user defined function (a.k.a.
percent_rank() Window function: returns the relative rank (i.e.
percentile_approx(col, percentage[, accuracy]) Returns the approximate percentile of the numeric column col which is the smallest value in the ordered col values (sorted from least to greatest) such that no more than percentage of col values is less than the value or equal to that value.
posexplode(col) Returns a new row for each element with position in the given array or map.
posexplode_outer(col) Returns a new row for each element with position in the given array or map.
pow(col1, col2) Returns the value of the first argument raised to the power of the second argument.
quarter(col) Extract the quarter of a given date as integer.
radians(col) Converts an angle measured in degrees to an approximately equivalent angle measured in radians.
raise_error(errMsg) Throws an exception with the provided error message.
rand([seed]) Generates a random column with independent and identically distributed (i.i.d.) samples uniformly distributed in [0.0, 1.0).
randn([seed]) Generates a column with independent and identically distributed (i.i.d.) samples from the standard normal distribution.
rank() Window function: returns the rank of rows within a window partition.
regexp_extract(str, pattern, idx) Extract a specific group matched by a Java regex, from the specified string column.
regexp_replace(str, pattern, replacement) Replace all substrings of the specified string value that match regexp with rep.
repeat(col, n) Repeats a string column n times, and returns it as a new string column.
reverse(col) Collection function: returns a reversed string or an array with reverse order of elements.
rint(col) Returns the double value that is closest in value to the argument and is equal to a mathematical integer.
round(col[, scale]) Round the given value to scale decimal places using HALF_UP rounding mode if scale >= 0 or at integral part when scale < 0.
row_number() Window function: returns a sequential number starting at 1 within a window partition.
rpad(col, len, pad) Right-pad the string column to width len with pad.
rtrim(col) Trim the spaces from right end for the specified string value.
schema_of_csv(csv[, options]) Parses a CSV string and infers its schema in DDL format.
schema_of_json(json[, options]) Parses a JSON string and infers its schema in DDL format.
second(col) Extract the seconds of a given date as integer.
sequence(start, stop[, step]) Generate a sequence of integers from start to stop, incrementing by step.
sha1(col) Returns the hex string result of SHA-1.
sha2(col, numBits) Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, and SHA-512).
shiftLeft(col, numBits) Shift the given value numBits left.
shiftRight(col, numBits) (Signed) shift the given value numBits right.
shiftRightUnsigned(col, numBits) Unsigned shift the given value numBits right.
shuffle(col) Collection function: Generates a random permutation of the given array.
signum(col) Computes the signum of the given value.
sin(col) New in version 1.4.0.
sinh(col) New in version 1.4.0.
size(col) Collection function: returns the length of the array or map stored in the column.
skewness(col) Aggregate function: returns the skewness of the values in a group.
slice(x, start, length) Collection function: returns an array containing all the elements in x from index start (array indices start at 1, or from the end if start is negative) with the specified length.
sort_array(col[, asc]) Collection function: sorts the input array in ascending or descending order according to the natural ordering of the array elements.
soundex(col) Returns the SoundEx encoding for a string
spark_partition_id() A column for partition ID.
split(str, pattern[, limit]) Splits str around matches of the given pattern.
sqrt(col) Computes the square root of the specified float value.
stddev(col) Aggregate function: alias for stddev_samp.
stddev_pop(col) Aggregate function: returns population standard deviation of the expression in a group.
stddev_samp(col) Aggregate function: returns the unbiased sample standard deviation of the expression in a group.
struct(*cols) Creates a new struct column.
substring(str, pos, len) Substring starts at pos and is of length len when str is String type or returns the slice of byte array that starts at pos in byte and is of length len when str is Binary type.
substring_index(str, delim, count) Returns the substring from string str before count occurrences of the delimiter delim.
sum(col) Aggregate function: returns the sum of all values in the expression.
sumDistinct(col) Aggregate function: returns the sum of distinct values in the expression.
tan(col) New in version 1.4.0.
tanh(col) New in version 1.4.0.
timestamp_seconds(col) New in version 3.1.0.
toDegrees(col) Deprecated since version 2.1.0.
toRadians(col) Deprecated since version 2.1.0.
to_csv(col[, options]) Converts a column containing a StructType into a CSV string.
to_date(col[, format]) Converts a Column into pyspark.sql.types.DateType using the optionally specified format.
to_json(col[, options]) Converts a column containing a StructType, ArrayType or a MapType into a JSON string.
to_timestamp(col[, format]) Converts a Column into pyspark.sql.types.TimestampType using the optionally specified format.
to_utc_timestamp(timestamp, tz) This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE.
transform(col, f) Returns an array of elements after applying a transformation to each element in the input array.
transform_keys(col, f) Applies a function to every key-value pair in a map and returns a map with the results of those applications as the new keys for the pairs.
transform_values(col, f) Applies a function to every key-value pair in a map and returns a map with the results of those applications as the new values for the pairs.
translate(srcCol, matching, replace) A function translate any character in the srcCol by a character in matching.
trim(col) Trim the spaces from both ends for the specified string column.
trunc(date, format) Returns date truncated to the unit specified by the format.
udf([f, returnType]) Creates a user defined function (UDF).
unbase64(col) Decodes a BASE64 encoded string column and returns it as a binary column.
unhex(col) Inverse of hex.
unix_timestamp([timestamp, format]) Convert time string with given pattern ('yyyy-MM-dd HH:mm:ss', by default) to Unix time stamp (in seconds), using the default timezone and the default locale, return null if fail.
upper(col) Converts a string expression to upper case.
var_pop(col) Aggregate function: returns the population variance of the values in a group.
var_samp(col) Aggregate function: returns the unbiased sample variance of the values in a group.
variance(col) Aggregate function: alias for var_samp
weekofyear(col) Extract the week number of a given date as integer.
when(condition, value) Evaluates a list of conditions and returns one of multiple possible result expressions.
window(timeColumn, windowDuration[, . . . ]) Bucketize rows into one or more time windows given a timestamp specifying column.
xxhash64(*cols) Calculates the hash code of given columns using the 64-bit variant of the xxHash algorithm, and returns the result as a long column.
year(col) Extract the year of a given date as integer.
years(col) Partition transform function: A transform for timestamps and dates to partition data into years.
zip_with(col1, col2, f) Merge two given arrays, element-wise, into a single array using a function.

pyspark.sql.functions.abs

pyspark.sql.functions.abs(col)
Computes the absolute value.

New in version 1.3.
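This entry has no example in the reference; a minimal, hypothetical doctest-style sketch (assuming the usual spark session and the lit helper from pyspark.sql.functions):

>>> from pyspark.sql.functions import abs, lit
>>> spark.range(1).select(abs(lit(-5)).alias('v')).collect()  # |-5| == 5
[Row(v=5)]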

pyspark.sql.functions.acos

pyspark.sql.functions.acos(col)
New in version 1.4.0.

Returns

Column inverse cosine of col, as if computed by java.lang.Math.acos()
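A minimal, hypothetical sketch of the inverse cosine (assuming the usual spark session and lit; acos(1.0) is 0.0):

>>> from pyspark.sql.functions import acos, lit
>>> spark.range(1).select(acos(lit(1.0)).alias('v')).collect()
[Row(v=0.0)]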

pyspark.sql.functions.acosh

pyspark.sql.functions.acosh(col)
Computes inverse hyperbolic cosine of the input column.

New in version 3.1.0.

Returns

Column
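Similarly, a minimal, hypothetical sketch for the inverse hyperbolic cosine (acosh(1.0) is 0.0; assumes the usual spark session and lit):

>>> from pyspark.sql.functions import acosh, lit
>>> spark.range(1).select(acosh(lit(1.0)).alias('v')).collect()
[Row(v=0.0)]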


pyspark.sql.functions.add_months

pyspark.sql.functions.add_months(start, months)
Returns the date that is months months after start

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(add_months(df.dt, 1).alias('next_month')).collect()
[Row(next_month=datetime.date(2015, 5, 8))]

pyspark.sql.functions.aggregate

pyspark.sql.functions.aggregate(col, zero, merge, finish=None)
Applies a binary operator to an initial state and all elements in the array, and reduces this to a single state. The final state is converted into the final result by applying a finish function.

Both functions can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).

New in version 3.1.0.

Parameters

col [Column or str] name of column or expression

zero [Column or str] initial value. Name of column or expression

merge [function] a binary function (acc: Column, x: Column) -> Column... returning expression of the same type as zero

finish [function] an optional unary function (x: Column) -> Column: ... used to convert accumulated value.

Returns

pyspark.sql.Column

Examples

>>> df = spark.createDataFrame([(1, [20.0, 4.0, 2.0, 6.0, 10.0])], ("id", "values"))
>>> df.select(aggregate("values", lit(0.0), lambda acc, x: acc + x).alias("sum")).show()
+----+
| sum|
+----+
|42.0|
+----+


>>> def merge(acc, x):
...     count = acc.count + 1
...     sum = acc.sum + x
...     return struct(count.alias("count"), sum.alias("sum"))
>>> df.select(
...     aggregate(
...         "values",
...         struct(lit(0).alias("count"), lit(0.0).alias("sum")),
...         merge,
...         lambda acc: acc.sum / acc.count,
...     ).alias("mean")
... ).show()
+----+
|mean|
+----+
| 8.4|
+----+

pyspark.sql.functions.approxCountDistinct

pyspark.sql.functions.approxCountDistinct(col, rsd=None)
Deprecated since version 2.1.0: Use approx_count_distinct() instead.

New in version 1.3.

pyspark.sql.functions.approx_count_distinct

pyspark.sql.functions.approx_count_distinct(col, rsd=None)
Aggregate function: returns a new Column for approximate distinct count of column col.

New in version 2.1.0.

Parameters

col [Column or str]

rsd [float, optional] maximum relative standard deviation allowed (default = 0.05). For rsd <0.01, it is more efficient to use countDistinct()

Examples

>>> df.agg(approx_count_distinct(df.age).alias('distinct_ages')).collect()
[Row(distinct_ages=2)]


pyspark.sql.functions.array

pyspark.sql.functions.array(*cols)
Creates a new array column.

New in version 1.4.0.

Parameters

cols [Column or str] column names or Columns that have the same data type.

Examples

>>> df.select(array('age', 'age').alias("arr")).collect()
[Row(arr=[2, 2]), Row(arr=[5, 5])]
>>> df.select(array([df.age, df.age]).alias("arr")).collect()
[Row(arr=[2, 2]), Row(arr=[5, 5])]

pyspark.sql.functions.array_contains

pyspark.sql.functions.array_contains(col, value)
Collection function: returns null if the array is null, true if the array contains the given value, and false otherwise.

New in version 1.5.0.

Parameters

col [Column or str] name of column containing array

value : value or column to check for in array

Examples

>>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])>>> df.select(array_contains(df.data, "a")).collect()[Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]>>> df.select(array_contains(df.data, lit("a"))).collect()[Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]

pyspark.sql.functions.array_distinct

pyspark.sql.functions.array_distinct(col)
Collection function: removes duplicate values from the array.

New in version 2.4.0.

Parameters

col [Column or str] name of column or expression


Examples

>>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data'])
>>> df.select(array_distinct(df.data)).collect()
[Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])]

pyspark.sql.functions.array_except

pyspark.sql.functions.array_except(col1, col2)
Collection function: returns an array of the elements in col1 but not in col2, without duplicates.

New in version 2.4.0.

Parameters

col1 [Column or str] name of column containing array

col2 [Column or str] name of column containing array

Examples

>>> from pyspark.sql import Row>>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])>>> df.select(array_except(df.c1, df.c2)).collect()[Row(array_except(c1, c2)=['b'])]

pyspark.sql.functions.array_intersect

pyspark.sql.functions.array_intersect(col1, col2)
Collection function: returns an array of the elements in the intersection of col1 and col2, without duplicates.

New in version 2.4.0.

Parameters

col1 [Column or str] name of column containing array

col2 [Column or str] name of column containing array

Examples

>>> from pyspark.sql import Row>>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])>>> df.select(array_intersect(df.c1, df.c2)).collect()[Row(array_intersect(c1, c2)=['a', 'c'])]


pyspark.sql.functions.array_join

pyspark.sql.functions.array_join(col, delimiter, null_replacement=None)
Concatenates the elements of column using the delimiter. Null values are replaced with null_replacement if set, otherwise they are ignored.

New in version 2.4.0.

Examples

>>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data'])>>> df.select(array_join(df.data, ",").alias("joined")).collect()[Row(joined='a,b,c'), Row(joined='a')]>>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect()[Row(joined='a,b,c'), Row(joined='a,NULL')]

pyspark.sql.functions.array_max

pyspark.sql.functions.array_max(col)Collection function: returns the maximum value of the array.

New in version 2.4.0.

Parameters

col [Column or str] name of column or expression

Examples

>>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
>>> df.select(array_max(df.data).alias('max')).collect()
[Row(max=3), Row(max=10)]

pyspark.sql.functions.array_min

pyspark.sql.functions.array_min(col)Collection function: returns the minimum value of the array.

New in version 2.4.0.

Parameters

col [Column or str] name of column or expression


Examples

>>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
>>> df.select(array_min(df.data).alias('min')).collect()
[Row(min=1), Row(min=-1)]

pyspark.sql.functions.array_position

pyspark.sql.functions.array_position(col, value)Collection function: Locates the position of the first occurrence of the given value in the given array. Returnsnull if either of the arguments are null.

New in version 2.4.0.

Notes

The position is not zero based, but 1 based index. Returns 0 if the given value could not be found in the array.

Examples

>>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])
>>> df.select(array_position(df.data, "a")).collect()
[Row(array_position(data, a)=3), Row(array_position(data, a)=0)]

pyspark.sql.functions.array_remove

pyspark.sql.functions.array_remove(col, element)Collection function: Remove all elements that equal to element from the given array.

New in version 2.4.0.

Parameters

col [Column or str] name of column containing array

element : element to be removed from the array

Examples

>>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data'])
>>> df.select(array_remove(df.data, 1)).collect()
[Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])]


pyspark.sql.functions.array_repeat

pyspark.sql.functions.array_repeat(col, count)Collection function: creates an array containing a column repeated count times.

New in version 2.4.0.

Examples

>>> df = spark.createDataFrame([('ab',)], ['data'])
>>> df.select(array_repeat(df.data, 3).alias('r')).collect()
[Row(r=['ab', 'ab', 'ab'])]

pyspark.sql.functions.array_sort

pyspark.sql.functions.array_sort(col)Collection function: sorts the input array in ascending order. The elements of the input array must be orderable.Null elements will be placed at the end of the returned array.

New in version 2.4.0.

Parameters

col [Column or str] name of column or expression

Examples

>>> df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data'])
>>> df.select(array_sort(df.data).alias('r')).collect()
[Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])]

pyspark.sql.functions.array_union

pyspark.sql.functions.array_union(col1, col2)Collection function: returns an array of the elements in the union of col1 and col2, without duplicates.

New in version 2.4.0.

Parameters

col1 [Column or str] name of column containing array

col2 [Column or str] name of column containing array


Examples

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
>>> df.select(array_union(df.c1, df.c2)).collect()
[Row(array_union(c1, c2)=['b', 'a', 'c', 'd', 'f'])]

pyspark.sql.functions.arrays_overlap

pyspark.sql.functions.arrays_overlap(a1, a2)Collection function: returns true if the arrays contain any common non-null element; if not, returns null if boththe arrays are non-empty and any of them contains a null element; returns false otherwise.

New in version 2.4.0.

Examples

>>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y'])
>>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect()
[Row(overlap=True), Row(overlap=False)]

pyspark.sql.functions.arrays_zip

pyspark.sql.functions.arrays_zip(*cols)Collection function: Returns a merged array of structs in which the N-th struct contains all N-th values of inputarrays.

New in version 2.4.0.

Parameters

cols [Column or str] columns of arrays to be merged.

Examples

>>> from pyspark.sql.functions import arrays_zip
>>> df = spark.createDataFrame([(([1, 2, 3], [2, 3, 4]))], ['vals1', 'vals2'])
>>> df.select(arrays_zip(df.vals1, df.vals2).alias('zipped')).collect()
[Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])]


pyspark.sql.functions.asc

pyspark.sql.functions.asc(col)Returns a sort expression based on the ascending order of the given column name.

New in version 1.3.

pyspark.sql.functions.asc_nulls_first

pyspark.sql.functions.asc_nulls_first(col)Returns a sort expression based on the ascending order of the given column name, and null values return beforenon-null values.

New in version 2.4.

pyspark.sql.functions.asc_nulls_last

pyspark.sql.functions.asc_nulls_last(col)Returns a sort expression based on the ascending order of the given column name, and null values appear afternon-null values.

New in version 2.4.
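
A minimal sketch of the three ascending sort expressions (added here for illustration, not part of the original reference; it assumes an active SparkSession named spark and a hypothetical nullable column 'v'):

>>> df = spark.createDataFrame([(1,), (None,), (3,)], ['v'])
>>> df.orderBy(asc('v')).collect()   # ascending order places nulls first by default
[Row(v=None), Row(v=1), Row(v=3)]
>>> df.orderBy(asc_nulls_last('v')).collect()
[Row(v=1), Row(v=3), Row(v=None)]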

pyspark.sql.functions.ascii

pyspark.sql.functions.ascii(col)Computes the numeric value of the first character of the string column.

New in version 1.5.
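
As a brief added illustration (assuming an active SparkSession named spark), ascii returns the code point of the first character of the string:

>>> spark.createDataFrame([('Spark',)], ['a']).select(ascii('a').alias('v')).collect()
[Row(v=83)]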

pyspark.sql.functions.asin

pyspark.sql.functions.asin(col)New in version 1.3.0.

Returns

Column inverse sine of col, as if computed by java.lang.Math.asin()

pyspark.sql.functions.asinh

pyspark.sql.functions.asinh(col)Computes inverse hyperbolic sine of the input column.

New in version 3.1.0.

Returns

Column


pyspark.sql.functions.assert_true

pyspark.sql.functions.assert_true(col, errMsg=None)Returns null if the input column is true; throws an exception with the provided error message otherwise.

New in version 3.1.0.

Examples

>>> df = spark.createDataFrame([(0,1)], ['a', 'b'])
>>> df.select(assert_true(df.a < df.b).alias('r')).collect()
[Row(r=None)]
>>> df = spark.createDataFrame([(0,1)], ['a', 'b'])
>>> df.select(assert_true(df.a < df.b, df.a).alias('r')).collect()
[Row(r=None)]
>>> df = spark.createDataFrame([(0,1)], ['a', 'b'])
>>> df.select(assert_true(df.a < df.b, 'error').alias('r')).collect()
[Row(r=None)]

pyspark.sql.functions.atan

pyspark.sql.functions.atan(col)New in version 1.4.0.

Returns

Column inverse tangent of col, as if computed by java.lang.Math.atan()

pyspark.sql.functions.atanh

pyspark.sql.functions.atanh(col)Computes inverse hyperbolic tangent of the input column.

New in version 3.1.0.

Returns

Column

pyspark.sql.functions.atan2

pyspark.sql.functions.atan2(col1, col2)New in version 1.4.0.

Parameters

col1 [str, Column or float] coordinate on y-axis

col2 [str, Column or float] coordinate on x-axis

Returns

Column the theta component of the point (r, theta) in polar coordinates that corresponds to thepoint (x, y) in Cartesian coordinates, as if computed by java.lang.Math.atan2()


pyspark.sql.functions.avg

pyspark.sql.functions.avg(col)Aggregate function: returns the average of the values in a group.

New in version 1.3.
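
A short sketch of typical usage (added example, assuming a SparkSession named spark):

>>> df = spark.createDataFrame([(2,), (5,)], ['age'])
>>> df.agg(avg('age').alias('avg_age')).collect()
[Row(avg_age=3.5)]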

pyspark.sql.functions.base64

pyspark.sql.functions.base64(col)Computes the BASE64 encoding of a binary column and returns it as a string column.

New in version 1.5.
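
For illustration (added, not from the original text; assumes spark is an active SparkSession), encoding a string column yields its Base64 text:

>>> df = spark.createDataFrame([('Spark',)], ['a'])
>>> df.select(base64(df.a).alias('b64')).collect()
[Row(b64='U3Bhcms=')]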

pyspark.sql.functions.bin

pyspark.sql.functions.bin(col)Returns the string representation of the binary value of the given column.

New in version 1.5.0.

Examples

>>> df.select(bin(df.age).alias('c')).collect()
[Row(c='10'), Row(c='101')]

pyspark.sql.functions.bitwiseNOT

pyspark.sql.functions.bitwiseNOT(col)Computes bitwise not.

New in version 1.4.
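
A minimal added sketch of the two's-complement result (assumes a SparkSession named spark):

>>> df = spark.createDataFrame([(0,), (5,)], ['n'])
>>> df.select(bitwiseNOT(df.n).alias('r')).collect()
[Row(r=-1), Row(r=-6)]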

pyspark.sql.functions.broadcast

pyspark.sql.functions.broadcast(df)Marks a DataFrame as small enough for use in broadcast joins.

New in version 1.6.
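
A hedged sketch of a broadcast join (added example; the data is made up and an active SparkSession named spark is assumed):

>>> small = spark.createDataFrame([(1, 'a')], ['id', 'v'])
>>> large = spark.createDataFrame([(1, 10), (2, 20)], ['id', 'x'])
>>> large.join(broadcast(small), 'id').collect()  # hints Spark to broadcast `small`
[Row(id=1, x=10, v='a')]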

pyspark.sql.functions.bround

pyspark.sql.functions.bround(col, scale=0)Round the given value to scale decimal places using HALF_EVEN rounding mode if scale >= 0 or at integralpart when scale < 0.

New in version 2.0.0.


Examples

>>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect()
[Row(r=2.0)]

pyspark.sql.functions.bucket

pyspark.sql.functions.bucket(numBuckets, col)Partition transform function: A transform for any type that partitions by a hash of the input column.

New in version 3.1.0.

Notes

This function can be used only in combination with partitionedBy() method of the DataFrameWriterV2.

Examples

>>> df.writeTo("catalog.db.table").partitionedBy(
...     bucket(42, "ts")
... ).createOrReplace()

pyspark.sql.functions.cbrt

pyspark.sql.functions.cbrt(col)Computes the cube-root of the given value.

New in version 1.4.

pyspark.sql.functions.ceil

pyspark.sql.functions.ceil(col)Computes the ceiling of the given value.

New in version 1.4.
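
For illustration (added, assuming a SparkSession named spark); the result is an integral value:

>>> spark.createDataFrame([(-0.1,), (2.5,)], ['x']).select(ceil('x').alias('r')).collect()
[Row(r=0), Row(r=3)]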

pyspark.sql.functions.coalesce

pyspark.sql.functions.coalesce(*cols)Returns the first column that is not null.

New in version 1.4.0.


Examples

>>> cDf = spark.createDataFrame([(None, None), (1, None), (None, 2)], ("a", "b"))
>>> cDf.show()
+----+----+
|   a|   b|
+----+----+
|null|null|
|   1|null|
|null|   2|
+----+----+

>>> cDf.select(coalesce(cDf["a"], cDf["b"])).show()
+--------------+
|coalesce(a, b)|
+--------------+
|          null|
|             1|
|             2|
+--------------+

>>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show()
+----+----+----------------+
|   a|   b|coalesce(a, 0.0)|
+----+----+----------------+
|null|null|             0.0|
|   1|null|             1.0|
|null|   2|             0.0|
+----+----+----------------+

pyspark.sql.functions.col

pyspark.sql.functions.col(col)Returns a Column based on the given column name.’

New in version 1.3.

pyspark.sql.functions.collect_list

pyspark.sql.functions.collect_list(col)Aggregate function: returns a list of objects with duplicates.

New in version 1.6.0.

Notes

The function is non-deterministic because the order of collected results depends on the order of the rows whichmay be non-deterministic after a shuffle.


Examples

>>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',))
>>> df2.agg(collect_list('age')).collect()
[Row(collect_list(age)=[2, 5, 5])]

pyspark.sql.functions.collect_set

pyspark.sql.functions.collect_set(col)Aggregate function: returns a set of objects with duplicate elements eliminated.

New in version 1.6.0.

Notes

The function is non-deterministic because the order of collected results depends on the order of the rows whichmay be non-deterministic after a shuffle.

Examples

>>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',))
>>> df2.agg(collect_set('age')).collect()
[Row(collect_set(age)=[5, 2])]

pyspark.sql.functions.column

pyspark.sql.functions.column(col)Returns a Column based on the given column name.’

New in version 1.3.

pyspark.sql.functions.concat

pyspark.sql.functions.concat(*cols)Concatenates multiple input columns together into a single column. The function works with strings, binary andcompatible array columns.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
>>> df.select(concat(df.s, df.d).alias('s')).collect()
[Row(s='abcd123')]

>>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
>>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect()
[Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]


pyspark.sql.functions.concat_ws

pyspark.sql.functions.concat_ws(sep, *cols)Concatenates multiple input string columns together into a single string column, using the given separator.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
>>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect()
[Row(s='abcd-123')]

pyspark.sql.functions.conv

pyspark.sql.functions.conv(col, fromBase, toBase)Convert a number in a string column from one base to another.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([("010101",)], ['n'])
>>> df.select(conv(df.n, 2, 16).alias('hex')).collect()
[Row(hex='15')]

pyspark.sql.functions.corr

pyspark.sql.functions.corr(col1, col2)Returns a new Column for the Pearson Correlation Coefficient for col1 and col2.

New in version 1.6.0.

Examples

>>> a = range(20)
>>> b = [2 * x for x in range(20)]
>>> df = spark.createDataFrame(zip(a, b), ["a", "b"])
>>> df.agg(corr("a", "b").alias('c')).collect()
[Row(c=1.0)]


pyspark.sql.functions.cos

pyspark.sql.functions.cos(col)New in version 1.4.0.

Parameters

col [Column or str] angle in radians

Returns

Column cosine of the angle, as if computed by java.lang.Math.cos().

pyspark.sql.functions.cosh

pyspark.sql.functions.cosh(col)New in version 1.4.0.

Parameters

col [Column or str] hyperbolic angle

Returns

Column hyperbolic cosine of the angle, as if computed by java.lang.Math.cosh()

pyspark.sql.functions.count

pyspark.sql.functions.count(col)Aggregate function: returns the number of items in a group.

New in version 1.3.
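
A brief added sketch (assuming a SparkSession named spark); nulls in the counted column are skipped:

>>> df = spark.createDataFrame([(1,), (None,), (3,)], ['v'])
>>> df.agg(count('v').alias('non_null')).collect()
[Row(non_null=2)]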

pyspark.sql.functions.countDistinct

pyspark.sql.functions.countDistinct(col, *cols)Returns a new Column for distinct count of col or cols.

New in version 1.3.0.

Examples

>>> df.agg(countDistinct(df.age, df.name).alias('c')).collect()
[Row(c=2)]

>>> df.agg(countDistinct("age", "name").alias('c')).collect()
[Row(c=2)]


pyspark.sql.functions.covar_pop

pyspark.sql.functions.covar_pop(col1, col2)Returns a new Column for the population covariance of col1 and col2.

New in version 2.0.0.

Examples

>>> a = [1] * 10
>>> b = [1] * 10
>>> df = spark.createDataFrame(zip(a, b), ["a", "b"])
>>> df.agg(covar_pop("a", "b").alias('c')).collect()
[Row(c=0.0)]

pyspark.sql.functions.covar_samp

pyspark.sql.functions.covar_samp(col1, col2)Returns a new Column for the sample covariance of col1 and col2.

New in version 2.0.0.

Examples

>>> a = [1] * 10
>>> b = [1] * 10
>>> df = spark.createDataFrame(zip(a, b), ["a", "b"])
>>> df.agg(covar_samp("a", "b").alias('c')).collect()
[Row(c=0.0)]

pyspark.sql.functions.crc32

pyspark.sql.functions.crc32(col)Calculates the cyclic redundancy check value (CRC32) of a binary column and returns the value as a bigint.

New in version 1.5.0.

Examples

>>> spark.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect()
[Row(crc32=2743272264)]


pyspark.sql.functions.create_map

pyspark.sql.functions.create_map(*cols)Creates a new map column.

New in version 2.0.0.

Parameters

cols [Column or str] column names or Columns that are grouped as key-value pairs, e.g.(key1, value1, key2, value2, . . . ).

Examples

>>> df.select(create_map('name', 'age').alias("map")).collect()
[Row(map={'Alice': 2}), Row(map={'Bob': 5})]
>>> df.select(create_map([df.name, df.age]).alias("map")).collect()
[Row(map={'Alice': 2}), Row(map={'Bob': 5})]

pyspark.sql.functions.cume_dist

pyspark.sql.functions.cume_dist()Window function: returns the cumulative distribution of values within a window partition, i.e. the fraction ofrows that are below the current row.

New in version 1.6.

pyspark.sql.functions.current_date

pyspark.sql.functions.current_date()Returns the current date at the start of query evaluation as a DateType column. All calls of current_date withinthe same query return the same value.

New in version 1.5.

pyspark.sql.functions.current_timestamp

pyspark.sql.functions.current_timestamp()Returns the current timestamp at the start of query evaluation as a TimestampType column. All calls ofcurrent_timestamp within the same query return the same value.

pyspark.sql.functions.date_add

pyspark.sql.functions.date_add(start, days)Returns the date that is days days after start

New in version 1.5.0.


Examples

>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])>>> df.select(date_add(df.dt, 1).alias('next_date')).collect()[Row(next_date=datetime.date(2015, 4, 9))]

pyspark.sql.functions.date_format

pyspark.sql.functions.date_format(date, format)Converts a date/timestamp/string to a value of string in the format specified by the date format given by thesecond argument.

A pattern could be for instance dd.MM.yyyy and could return a string like ‘18.03.1993’. All pattern letters ofdatetime pattern. can be used.

New in version 1.5.0.

Notes

Whenever possible, use specialized functions like year.

Examples

>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])>>> df.select(date_format('dt', 'MM/dd/yyy').alias('date')).collect()[Row(date='04/08/2015')]

pyspark.sql.functions.date_sub

pyspark.sql.functions.date_sub(start, days)Returns the date that is days days before start

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])>>> df.select(date_sub(df.dt, 1).alias('prev_date')).collect()[Row(prev_date=datetime.date(2015, 4, 7))]

pyspark.sql.functions.date_trunc

pyspark.sql.functions.date_trunc(format, timestamp)Returns timestamp truncated to the unit specified by the format.

New in version 2.3.0.

Parameters

format [str] ‘year’, ‘yyyy’, ‘yy’, ‘month’, ‘mon’, ‘mm’, ‘day’, ‘dd’, ‘hour’, ‘minute’, ‘second’,‘week’, ‘quarter’


timestamp [Column or str]

Examples

>>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t'])>>> df.select(date_trunc('year', df.t).alias('year')).collect()[Row(year=datetime.datetime(1997, 1, 1, 0, 0))]>>> df.select(date_trunc('mon', df.t).alias('month')).collect()[Row(month=datetime.datetime(1997, 2, 1, 0, 0))]

pyspark.sql.functions.datediff

pyspark.sql.functions.datediff(end, start)Returns the number of days from start to end.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2'])>>> df.select(datediff(df.d2, df.d1).alias('diff')).collect()[Row(diff=32)]

pyspark.sql.functions.dayofmonth

pyspark.sql.functions.dayofmonth(col)Extract the day of the month of a given date as integer.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])>>> df.select(dayofmonth('dt').alias('day')).collect()[Row(day=8)]

pyspark.sql.functions.dayofweek

pyspark.sql.functions.dayofweek(col)Extract the day of the week of a given date as integer.

New in version 2.3.0.


Examples

>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])>>> df.select(dayofweek('dt').alias('day')).collect()[Row(day=4)]

pyspark.sql.functions.dayofyear

pyspark.sql.functions.dayofyear(col)Extract the day of the year of a given date as integer.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])>>> df.select(dayofyear('dt').alias('day')).collect()[Row(day=98)]

pyspark.sql.functions.days

pyspark.sql.functions.days(col)Partition transform function: A transform for timestamps and dates to partition data into days.

New in version 3.1.0.

Notes

This function can be used only in combination with partitionedBy() method of the DataFrameWriterV2.

Examples

>>> df.writeTo("catalog.db.table").partitionedBy(
...     days("ts")
... ).createOrReplace()

pyspark.sql.functions.decode

pyspark.sql.functions.decode(col, charset)Computes the first argument into a string from a binary using the provided character set (one of ‘US-ASCII’,‘ISO-8859-1’, ‘UTF-8’, ‘UTF-16BE’, ‘UTF-16LE’, ‘UTF-16’).

New in version 1.5.


pyspark.sql.functions.degrees

pyspark.sql.functions.degrees(col)Converts an angle measured in radians to an approximately equivalent angle measured in degrees.

New in version 2.1.0.

Parameters

col [Column or str] angle in radians

Returns

Column angle in degrees, as if computed by java.lang.Math.toDegrees()

pyspark.sql.functions.dense_rank

pyspark.sql.functions.dense_rank()Window function: returns the rank of rows within a window partition, without any gaps.

The difference between rank and dense_rank is that dense_rank leaves no gaps in the ranking sequence when there are ties. That is, if you were ranking a competition using dense_rank and had three people tie for second place, you would say that all three were in second place and that the next person came in third. rank, by contrast, assigns sequential numbers, so the person who finished after the three-way tie for second would register as coming in fifth.

This is equivalent to the DENSE_RANK function in SQL.

New in version 1.6.
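
A small added window example (not from the original reference; it assumes an active SparkSession named spark and uses pyspark.sql.Window):

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([("a", 1), ("a", 1), ("a", 2)], ["cat", "val"])
>>> w = Window.partitionBy("cat").orderBy("val")
>>> df.select("val", dense_rank().over(w).alias("dr")).collect()
[Row(val=1, dr=1), Row(val=1, dr=1), Row(val=2, dr=2)]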

pyspark.sql.functions.desc

pyspark.sql.functions.desc(col)Returns a sort expression based on the descending order of the given column name.

New in version 1.3.

pyspark.sql.functions.desc_nulls_first

pyspark.sql.functions.desc_nulls_first(col)Returns a sort expression based on the descending order of the given column name, and null values appearbefore non-null values.

New in version 2.4.

pyspark.sql.functions.desc_nulls_last

pyspark.sql.functions.desc_nulls_last(col)Returns a sort expression based on the descending order of the given column name, and null values appear afternon-null values.

New in version 2.4.


pyspark.sql.functions.element_at

pyspark.sql.functions.element_at(col, extraction)Collection function: Returns element of array at given index in extraction if col is array. Returns value for thegiven key in extraction if col is map.

New in version 2.4.0.

Parameters

col [Column or str] name of column containing array or map

extraction : index to check for in array or key to check for in map

Notes

The position is not zero based, but 1 based index.

Examples

>>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
>>> df.select(element_at(df.data, 1)).collect()
[Row(element_at(data, 1)='a'), Row(element_at(data, 1)=None)]

>>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data'])
>>> df.select(element_at(df.data, lit("a"))).collect()
[Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)]

pyspark.sql.functions.encode

pyspark.sql.functions.encode(col, charset)Computes the first argument into a binary from a string using the provided character set (one of ‘US-ASCII’,‘ISO-8859-1’, ‘UTF-8’, ‘UTF-16BE’, ‘UTF-16LE’, ‘UTF-16’).

New in version 1.5.
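
As a hedged round-trip sketch for encode and decode (added example, assuming a SparkSession named spark):

>>> df = spark.createDataFrame([('Spark',)], ['s'])
>>> df.select(decode(encode(df.s, 'UTF-8'), 'UTF-8').alias('r')).collect()
[Row(r='Spark')]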

pyspark.sql.functions.exists

pyspark.sql.functions.exists(col, f)Returns whether a predicate holds for one or more elements in the array.

New in version 3.1.0.

Parameters

col [Column or str] name of column or expression

f [function] (x: Column) -> Column: ... returning the Boolean expression. Can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).

Returns

pyspark.sql.Column


Examples

>>> df = spark.createDataFrame([(1, [1, 2, 3, 4]), (2, [3, -1, 0])], ("key", "values"))
>>> df.select(exists("values", lambda x: x < 0).alias("any_negative")).show()
+------------+
|any_negative|
+------------+
|       false|
|        true|
+------------+

pyspark.sql.functions.exp

pyspark.sql.functions.exp(col)Computes the exponential of the given value.

New in version 1.4.
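
For illustration (added, assuming a SparkSession named spark):

>>> spark.createDataFrame([(0.0,), (1.0,)], ['x']).select(exp('x').alias('r')).collect()
[Row(r=1.0), Row(r=2.718281828459045)]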

pyspark.sql.functions.explode

pyspark.sql.functions.explode(col)Returns a new row for each element in the given array or map. Uses the default column name col for elementsin the array and key and value for elements in the map unless specified otherwise.

New in version 1.4.0.

Examples

>>> from pyspark.sql import Row
>>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
>>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
[Row(anInt=1), Row(anInt=2), Row(anInt=3)]

>>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
+---+-----+
|key|value|
+---+-----+
|  a|    b|
+---+-----+

pyspark.sql.functions.explode_outer

pyspark.sql.functions.explode_outer(col)Returns a new row for each element in the given array or map. Unlike explode, if the array/map is null or emptythen null is produced. Uses the default column name col for elements in the array and key and value for elementsin the map unless specified otherwise.

New in version 2.3.0.


Examples

>>> df = spark.createDataFrame(
...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
...     ("id", "an_array", "a_map")
... )
>>> df.select("id", "an_array", explode_outer("a_map")).show()
+---+----------+----+-----+
| id|  an_array| key|value|
+---+----------+----+-----+
|  1|[foo, bar]|   x|  1.0|
|  2|        []|null| null|
|  3|      null|null| null|
+---+----------+----+-----+

>>> df.select("id", "a_map", explode_outer("an_array")).show()
+---+----------+----+
| id|     a_map| col|
+---+----------+----+
|  1|{x -> 1.0}| foo|
|  1|{x -> 1.0}| bar|
|  2|        {}|null|
|  3|      null|null|
+---+----------+----+

pyspark.sql.functions.expm1

pyspark.sql.functions.expm1(col)Computes the exponential of the given value minus one.

New in version 1.4.

pyspark.sql.functions.expr

pyspark.sql.functions.expr(str)Parses the expression string into the column that it represents

New in version 1.5.0.

Examples

>>> df.select(expr("length(name)")).collect()
[Row(length(name)=5), Row(length(name)=3)]


pyspark.sql.functions.factorial

pyspark.sql.functions.factorial(col)Computes the factorial of the given value.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([(5,)], ['n'])
>>> df.select(factorial(df.n).alias('f')).collect()
[Row(f=120)]

pyspark.sql.functions.filter

pyspark.sql.functions.filter(col, f)Returns an array of elements for which a predicate holds in a given array.

New in version 3.1.0.

Parameters

col [Column or str] name of column or expression

f [function] A function that returns the Boolean expression. Can take one of the followingforms:

• Unary (x: Column) -> Column: ...

• Binary (x: Column, i: Column) -> Column..., where the second argument isa 0-based index of the element.

and can use methods of pyspark.sql.Column, functions defined inpyspark.sql.functions and Scala UserDefinedFunctions. PythonUserDefinedFunctions are not supported (SPARK-27052).

Returns

pyspark.sql.Column

Examples

>>> df = spark.createDataFrame(
...     [(1, ["2018-09-20", "2019-02-03", "2019-07-01", "2020-06-01"])],
...     ("key", "values")
... )
>>> def after_second_quarter(x):
...     return month(to_date(x)) > 6
>>> df.select(
...     filter("values", after_second_quarter).alias("after_second_quarter")
... ).show(truncate=False)
+------------------------+
|after_second_quarter    |
+------------------------+
|[2018-09-20, 2019-07-01]|
+------------------------+


pyspark.sql.functions.first

pyspark.sql.functions.first(col, ignorenulls=False)Aggregate function: returns the first value in a group.

By default the function returns the first value it sees. It will return the first non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned.

New in version 1.3.0.

Notes

The function is non-deterministic because its result depends on the order of the rows, which may be non-deterministic after a shuffle.
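
A short added sketch of ignoreNulls (it assumes a SparkSession named spark, and that the tiny local dataset keeps its input order):

>>> df = spark.createDataFrame([("a", None), ("a", 1), ("b", 2)], ["k", "v"])
>>> df.groupBy("k").agg(first("v", ignorenulls=True).alias("v")).orderBy("k").collect()
[Row(k='a', v=1), Row(k='b', v=2)]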

pyspark.sql.functions.flatten

pyspark.sql.functions.flatten(col)Collection function: creates a single array from an array of arrays. If a structure of nested arrays is deeper thantwo levels, only one level of nesting is removed.

New in version 2.4.0.

Parameters

col [Column or str] name of column or expression

Examples

>>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data'])
>>> df.select(flatten(df.data).alias('r')).collect()
[Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)]

pyspark.sql.functions.floor

pyspark.sql.functions.floor(col)Computes the floor of the given value.

New in version 1.4.
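
For illustration (added, assuming a SparkSession named spark); like ceil, the result is an integral value:

>>> spark.createDataFrame([(2.5,), (-0.1,)], ['x']).select(floor('x').alias('r')).collect()
[Row(r=2), Row(r=-1)]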

pyspark.sql.functions.forall

pyspark.sql.functions.forall(col, f)Returns whether a predicate holds for every element in the array.

New in version 3.1.0.

Parameters

col [Column or str] name of column or expression


f [function] (x: Column) -> Column: ... returning the Boolean expression. Can usemethods of pyspark.sql.Column, functions defined in pyspark.sql.functionsand Scala UserDefinedFunctions. Python UserDefinedFunctions are not sup-ported (SPARK-27052).

Returns

pyspark.sql.Column

Examples

>>> df = spark.createDataFrame(
...     [(1, ["bar"]), (2, ["foo", "bar"]), (3, ["foobar", "foo"])],
...     ("key", "values")
... )
>>> df.select(forall("values", lambda x: x.rlike("foo")).alias("all_foo")).show()
+-------+
|all_foo|
+-------+
|  false|
|  false|
|   true|
+-------+

pyspark.sql.functions.format_number

pyspark.sql.functions.format_number(col, d)
Formats the number X to a format like '#,###,###.##', rounded to d decimal places with HALF_EVEN round mode, and returns the result as a string.

New in version 1.5.0.

Parameters

col [Column or str] the column name of the numeric value to be formatted

d [int] the number of decimal places to round to

Examples

>>> spark.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
[Row(v='5.0000')]

pyspark.sql.functions.format_string

pyspark.sql.functions.format_string(format, *cols)Formats the arguments in printf-style and returns the result as a string column.

New in version 1.5.0.

Parameters

format [str] string that can contain embedded format tags and used as result column’s value

cols [Column or str] column names or Columns to be used in formatting


Examples

>>> df = spark.createDataFrame([(5, "hello")], ['a', 'b'])
>>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect()
[Row(v='5 hello')]

pyspark.sql.functions.from_csv

pyspark.sql.functions.from_csv(col, schema, options=None)Parses a column containing a CSV string to a row with the specified schema. Returns null, in the case of anunparseable string.

New in version 3.0.0.

Parameters

col [Column or str] string column in CSV format

schema :class:`Column` or str a string with schema in DDL format to use when parsing theCSV column.

options [dict, optional] options to control parsing. accepts the same options as the CSV data-source

Examples

>>> data = [("1,2,3",)]
>>> df = spark.createDataFrame(data, ("value",))
>>> df.select(from_csv(df.value, "a INT, b INT, c INT").alias("csv")).collect()
[Row(csv=Row(a=1, b=2, c=3))]
>>> value = data[0][0]
>>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect()
[Row(csv=Row(_c0=1, _c1=2, _c2=3))]
>>> data = [(" abc",)]
>>> df = spark.createDataFrame(data, ("value",))
>>> options = {'ignoreLeadingWhiteSpace': True}
>>> df.select(from_csv(df.value, "s string", options).alias("csv")).collect()
[Row(csv=Row(s='abc'))]

pyspark.sql.functions.from_json

pyspark.sql.functions.from_json(col, schema, options=None)Parses a column containing a JSON string into a MapType with StringType as keys type, StructType orArrayType with the specified schema. Returns null, in the case of an unparseable string.

New in version 2.1.0.

Parameters

col [Column or str] string column in json format

schema [DataType or str] a StructType or ArrayType of StructType to use when parsing thejson column.

Changed in version 2.3: the DDL-formatted string is also supported for schema.


options [dict, optional] options to control parsing. accepts the same options as the json data-source

Examples

>>> from pyspark.sql.types import *
>>> data = [(1, '''{"a": 1}''')]
>>> schema = StructType([StructField("a", IntegerType())])
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=Row(a=1))]
>>> df.select(from_json(df.value, "a INT").alias("json")).collect()
[Row(json=Row(a=1))]
>>> df.select(from_json(df.value, "MAP<STRING,INT>").alias("json")).collect()
[Row(json={'a': 1})]
>>> data = [(1, '''[{"a": 1}]''')]
>>> schema = ArrayType(StructType([StructField("a", IntegerType())]))
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=[Row(a=1)])]
>>> schema = schema_of_json(lit('''{"a": 0}'''))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=Row(a=None))]
>>> data = [(1, '''[1, 2, 3]''')]
>>> schema = ArrayType(IntegerType())
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=[1, 2, 3])]

pyspark.sql.functions.from_unixtime

pyspark.sql.functions.from_unixtime(timestamp, format='yyyy-MM-dd HH:mm:ss')Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string representing thetimestamp of that moment in the current system time zone in the given format.

New in version 1.5.0.

Examples

>>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")>>> time_df = spark.createDataFrame([(1428476400,)], ['unix_time'])>>> time_df.select(from_unixtime('unix_time').alias('ts')).collect()[Row(ts='2015-04-08 00:00:00')]>>> spark.conf.unset("spark.sql.session.timeZone")


pyspark.sql.functions.from_utc_timestamp

pyspark.sql.functions.from_utc_timestamp(timestamp, tz)This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function takesa timestamp which is timezone-agnostic, and interprets it as a timestamp in UTC, and renders that timestamp asa timestamp in the given time zone.

However, timestamp in Spark represents number of microseconds from the Unix epoch, which is not timezone-agnostic. So in Spark this function just shift the timestamp value from UTC timezone to the given timezone.

This function may return confusing result if the input is a string with timezone, e.g. ‘2018-03-13T06:18:23+00:00’. The reason is that, Spark firstly cast the string to timestamp according to the timezonein the string, and finally display the result by converting the timestamp to string according to the session localtimezone.

New in version 1.5.0.

Parameters

timestamp [Column or str] the column that contains timestamps

tz [Column or str] A string detailing the time zone ID that the input should be adjusted to. Itshould be in the format of either region-based zone IDs or zone offsets. Region IDs musthave the form ‘area/city’, such as ‘America/Los_Angeles’. Zone offsets must be in theformat ‘(+|-)HH:mm’, for example ‘-08:00’ or ‘+01:00’. Also ‘UTC’ and ‘Z’ are supportedas aliases of ‘+00:00’. Other short names are not recommended to use because they can beambiguous.

Changed in version 2.4: tz can take a Column containing timezone ID strings.

Examples

>>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz'])
>>> df.select(from_utc_timestamp(df.ts, "PST").alias('local_time')).collect()
[Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))]
>>> df.select(from_utc_timestamp(df.ts, df.tz).alias('local_time')).collect()
[Row(local_time=datetime.datetime(1997, 2, 28, 19, 30))]

pyspark.sql.functions.get_json_object

pyspark.sql.functions.get_json_object(col, path)Extracts json object from a json string based on json path specified, and returns json string of the extracted jsonobject. It will return null if the input json string is invalid.

New in version 1.6.0.

Parameters

col [Column or str] string column in json format

path [str] path to the json object to extract


Examples

>>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')]
>>> df = spark.createDataFrame(data, ("key", "jstring"))
>>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \
...     get_json_object(df.jstring, '$.f2').alias("c1")).collect()
[Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)]

pyspark.sql.functions.greatest

pyspark.sql.functions.greatest(*cols)Returns the greatest value of the list of column names, skipping null values. This function takes at least 2parameters. It will return null iff all parameters are null.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
>>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect()
[Row(greatest=4)]

pyspark.sql.functions.grouping

pyspark.sql.functions.grouping(col)Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated or not, returns 1for aggregated or 0 for not aggregated in the result set.

New in version 2.0.0.

Examples

>>> df.cube("name").agg(grouping("name"), sum("age")).orderBy("name").show()
+-----+--------------+--------+
| name|grouping(name)|sum(age)|
+-----+--------------+--------+
| null|             1|       7|
|Alice|             0|       2|
|  Bob|             0|       5|
+-----+--------------+--------+


pyspark.sql.functions.grouping_id

pyspark.sql.functions.grouping_id(*cols)Aggregate function: returns the level of grouping, equals to

(grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + . . . + grouping(cn)

New in version 2.0.0.

Notes

The list of columns should match with grouping columns exactly, or empty (means all the grouping columns).

Examples

>>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show()
+-----+-------------+--------+
| name|grouping_id()|sum(age)|
+-----+-------------+--------+
| null|            1|       7|
|Alice|            0|       2|
|  Bob|            0|       5|
+-----+-------------+--------+

pyspark.sql.functions.hash

pyspark.sql.functions.hash(*cols)Calculates the hash code of given columns, and returns the result as an int column.

New in version 2.0.0.

Examples

>>> spark.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect()
[Row(hash=-757602832)]

pyspark.sql.functions.hex

pyspark.sql.functions.hex(col)Computes hex value of the given column, which could be pyspark.sql.types.StringType,pyspark.sql.types.BinaryType, pyspark.sql.types.IntegerType or pyspark.sql.types.LongType.

New in version 1.5.0.


Examples

>>> spark.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
[Row(hex(a)='414243', hex(b)='3')]

pyspark.sql.functions.hour

pyspark.sql.functions.hour(col)Extract the hours of a given date as integer.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts'])>>> df.select(hour('ts').alias('hour')).collect()[Row(hour=13)]

pyspark.sql.functions.hours

pyspark.sql.functions.hours(col)Partition transform function: A transform for timestamps to partition data into hours.

New in version 3.1.0.

Notes

This function can be used only in combination with partitionedBy() method of the DataFrameWriterV2.

Examples

>>> df.writeTo("catalog.db.table").partitionedBy(... hours("ts")... ).createOrReplace()

pyspark.sql.functions.hypot

pyspark.sql.functions.hypot(col1, col2)Computes sqrt(a^2 + b^2) without intermediate overflow or underflow.

New in version 1.4.
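
A minimal added sketch (assuming a SparkSession named spark) using the 3-4-5 triangle:

>>> df = spark.createDataFrame([(3.0, 4.0)], ['a', 'b'])
>>> df.select(hypot(df.a, df.b).alias('h')).collect()
[Row(h=5.0)]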


pyspark.sql.functions.initcap

pyspark.sql.functions.initcap(col)Translate the first letter of each word to upper case in the sentence.

New in version 1.5.0.

Examples

>>> spark.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect()
[Row(v='Ab Cd')]

pyspark.sql.functions.input_file_name

pyspark.sql.functions.input_file_name()Creates a string column for the file name of the current Spark task.

New in version 1.6.

pyspark.sql.functions.instr

pyspark.sql.functions.instr(str, substr)Locate the position of the first occurrence of substr column in the given string. Returns null if either of thearguments are null.

New in version 1.5.0.

Notes

The position is not zero based, but 1 based index. Returns 0 if substr could not be found in str.

Examples

>>> df = spark.createDataFrame([('abcd',)], ['s',])
>>> df.select(instr(df.s, 'b').alias('s')).collect()
[Row(s=2)]

pyspark.sql.functions.isnan

pyspark.sql.functions.isnan(col)An expression that returns true iff the column is NaN.

New in version 1.6.0.


Examples

>>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b"))
>>> df.select(isnan("a").alias("r1"), isnan(df.a).alias("r2")).collect()
[Row(r1=False, r2=False), Row(r1=True, r2=True)]

pyspark.sql.functions.isnull

pyspark.sql.functions.isnull(col)An expression that returns true iff the column is null.

New in version 1.6.0.

Examples

>>> df = spark.createDataFrame([(1, None), (None, 2)], ("a", "b"))>>> df.select(isnull("a").alias("r1"), isnull(df.a).alias("r2")).collect()[Row(r1=False, r2=False), Row(r1=True, r2=True)]

pyspark.sql.functions.json_tuple

pyspark.sql.functions.json_tuple(col, *fields)Creates a new row for a json column according to the given field names.

New in version 1.6.0.

Parameters

col [Column or str] string column in json format

fields [str] fields to extract

Examples

>>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')]
>>> df = spark.createDataFrame(data, ("key", "jstring"))
>>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect()
[Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)]

pyspark.sql.functions.kurtosis

pyspark.sql.functions.kurtosis(col)Aggregate function: returns the kurtosis of the values in a group.

New in version 1.6.


pyspark.sql.functions.lag

pyspark.sql.functions.lag(col, offset=1, default=None)
Window function: returns the value that is offset rows before the current row, and default if there are fewer than offset rows before the current row. For example, an offset of one will return the previous row at any given point in the window partition.

This is equivalent to the LAG function in SQL.

New in version 1.4.0.

Parameters

col [Column or str] name of column or expression

offset [int, optional] number of row to extend

default [optional] default value
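
A hedged added window sketch (assumes a SparkSession named spark and pyspark.sql.Window):

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([(1, 10), (1, 20), (1, 30)], ["grp", "v"])
>>> w = Window.partitionBy("grp").orderBy("v")
>>> df.select("v", lag("v", 1, 0).over(w).alias("prev")).collect()
[Row(v=10, prev=0), Row(v=20, prev=10), Row(v=30, prev=20)]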

pyspark.sql.functions.last

pyspark.sql.functions.last(col, ignorenulls=False)Aggregate function: returns the last value in a group.

The function by default returns the last values it sees. It will return the last non-null value it sees when ig-noreNulls is set to true. If all values are null, then null is returned.

New in version 1.3.0.

Notes

The function is non-deterministic because its results depends on the order of the rows which may be non-deterministic after a shuffle.

pyspark.sql.functions.last_day

pyspark.sql.functions.last_day(date)Returns the last day of the month which the given date belongs to.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('1997-02-10',)], ['d'])>>> df.select(last_day(df.d).alias('date')).collect()[Row(date=datetime.date(1997, 2, 28))]


pyspark.sql.functions.lead

pyspark.sql.functions.lead(col, offset=1, default=None)
Window function: returns the value that is offset rows after the current row, and default if there are fewer than offset rows after the current row. For example, an offset of one will return the next row at any given point in the window partition.

This is equivalent to the LEAD function in SQL.

New in version 1.4.0.

Parameters

col [Column or str] name of column or expression

offset [int, optional] number of row to extend

default [optional] default value
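
The mirror-image added sketch for lead (same assumptions as the lag example above):

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([(1, 10), (1, 20), (1, 30)], ["grp", "v"])
>>> w = Window.partitionBy("grp").orderBy("v")
>>> df.select("v", lead("v", 1).over(w).alias("next")).collect()
[Row(v=10, next=20), Row(v=20, next=30), Row(v=30, next=None)]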

pyspark.sql.functions.least

pyspark.sql.functions.least(*cols)Returns the least value of the list of column names, skipping null values. This function takes at least 2 parame-ters. It will return null iff all parameters are null.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])>>> df.select(least(df.a, df.b, df.c).alias("least")).collect()[Row(least=1)]

pyspark.sql.functions.length

pyspark.sql.functions.length(col)Computes the character length of string data or number of bytes of binary data. The length of character dataincludes the trailing spaces. The length of binary data includes binary zeros.

New in version 1.5.0.

Examples

>>> spark.createDataFrame([('ABC ',)], ['a']).select(length('a').alias('length')).collect()
[Row(length=4)]


pyspark.sql.functions.levenshtein

pyspark.sql.functions.levenshtein(left, right)Computes the Levenshtein distance of the two given strings.

New in version 1.5.0.

Examples

>>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])>>> df0.select(levenshtein('l', 'r').alias('d')).collect()[Row(d=3)]

pyspark.sql.functions.lit

pyspark.sql.functions.lit(col)Creates a Column of literal value.

New in version 1.3.0.

Examples

>>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1)[Row(height=5, spark_user=True)]

pyspark.sql.functions.locate

pyspark.sql.functions.locate(substr, str, pos=1)Locate the position of the first occurrence of substr in a string column, after position pos.

New in version 1.5.0.

Parameters

substr [str] a string

str [Column or str] a Column of pyspark.sql.types.StringType

pos [int, optional] start position (zero based)

Notes

The position is not zero based, but 1 based index. Returns 0 if substr could not be found in str.


Examples

>>> df = spark.createDataFrame([('abcd',)], ['s',])>>> df.select(locate('b', df.s, 1).alias('s')).collect()[Row(s=2)]

pyspark.sql.functions.log

pyspark.sql.functions.log(arg1, arg2=None)Returns the first argument-based logarithm of the second argument.

If there is only one argument, then this takes the natural logarithm of the argument.

New in version 1.5.0.

Examples

>>> df.select(log(10.0, df.age).alias('ten')).rdd.map(lambda l: str(l.ten)[:7]).collect()
['0.30102', '0.69897']

>>> df.select(log(df.age).alias('e')).rdd.map(lambda l: str(l.e)[:7]).collect()
['0.69314', '1.60943']

pyspark.sql.functions.log10

pyspark.sql.functions.log10(col)Computes the logarithm of the given value in Base 10.

New in version 1.4.

pyspark.sql.functions.log1p

pyspark.sql.functions.log1p(col)Computes the natural logarithm of the given value plus one.

New in version 1.4.

pyspark.sql.functions.log2

pyspark.sql.functions.log2(col)Returns the base-2 logarithm of the argument.

New in version 1.5.0.


Examples

>>> spark.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect()[Row(log2=2.0)]

pyspark.sql.functions.lower

pyspark.sql.functions.lower(col)Converts a string expression to lower case.

New in version 1.5.
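
For illustration (added, assuming a SparkSession named spark):

>>> spark.createDataFrame([('Spark SQL',)], ['s']).select(lower('s').alias('r')).collect()
[Row(r='spark sql')]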

pyspark.sql.functions.lpad

pyspark.sql.functions.lpad(col, len, pad)Left-pad the string column to width len with pad.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('abcd',)], ['s',])>>> df.select(lpad(df.s, 6, '#').alias('s')).collect()[Row(s='##abcd')]

pyspark.sql.functions.ltrim

pyspark.sql.functions.ltrim(col)Trim the spaces from left end for the specified string value.

New in version 1.5.

pyspark.sql.functions.map_concat

pyspark.sql.functions.map_concat(*cols)Returns the union of all the given maps.

New in version 2.4.0.

Parameters

cols [Column or str] column names or Columns


Examples

>>> from pyspark.sql.functions import map_concat
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c') as map2")
>>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False)
+------------------------+
|map3                    |
+------------------------+
|{1 -> a, 2 -> b, 3 -> c}|
+------------------------+

pyspark.sql.functions.map_entries

pyspark.sql.functions.map_entries(col)Collection function: Returns an unordered array of all entries in the given map.

New in version 3.0.0.

Parameters

col [Column or str] name of column or expression

Examples

>>> from pyspark.sql.functions import map_entries
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
>>> df.select(map_entries("data").alias("entries")).show()
+----------------+
|         entries|
+----------------+
|[{1, a}, {2, b}]|
+----------------+

pyspark.sql.functions.map_filter

pyspark.sql.functions.map_filter(col, f)Returns a map whose key-value pairs satisfy a predicate.

New in version 3.1.0.

Parameters

col [Column or str] name of column or expression

f [function] a binary function (k: Column, v: Column) -> Column... Can usemethods of pyspark.sql.Column, functions defined in pyspark.sql.functionsand Scala UserDefinedFunctions. Python UserDefinedFunctions are not sup-ported (SPARK-27052).

Returns

pyspark.sql.Column


Examples

>>> df = spark.createDataFrame([(1, {"foo": 42.0, "bar": 1.0, "baz": 32.0})], ("id", "data"))
>>> df.select(map_filter(
...     "data", lambda _, v: v > 30.0).alias("data_filtered")
... ).show(truncate=False)
+--------------------------+
|data_filtered             |
+--------------------------+
|{baz -> 32.0, foo -> 42.0}|
+--------------------------+

pyspark.sql.functions.map_from_arrays

pyspark.sql.functions.map_from_arrays(col1, col2)Creates a new map from two arrays.

New in version 2.4.0.

Parameters

col1 [Column or str] name of column containing a set of keys. All elements should not be null

col2 [Column or str] name of column containing a set of values

Examples

>>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v'])
>>> df.select(map_from_arrays(df.k, df.v).alias("map")).show()
+----------------+
|             map|
+----------------+
|{2 -> a, 5 -> b}|
+----------------+

pyspark.sql.functions.map_from_entries

pyspark.sql.functions.map_from_entries(col)Collection function: Returns a map created from the given array of entries.

New in version 2.4.0.

Parameters

col [Column or str] name of column or expression


Examples

>>> from pyspark.sql.functions import map_from_entries
>>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data")
>>> df.select(map_from_entries("data").alias("map")).show()
+----------------+
|             map|
+----------------+
|{1 -> a, 2 -> b}|
+----------------+

pyspark.sql.functions.map_keys

pyspark.sql.functions.map_keys(col)Collection function: Returns an unordered array containing the keys of the map.

New in version 2.3.0.

Parameters

col [Column or str] name of column or expression

Examples

>>> from pyspark.sql.functions import map_keys
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
>>> df.select(map_keys("data").alias("keys")).show()
+------+
|  keys|
+------+
|[1, 2]|
+------+

pyspark.sql.functions.map_values

pyspark.sql.functions.map_values(col)Collection function: Returns an unordered array containing the values of the map.

New in version 2.3.0.

Parameters

col [Column or str] name of column or expression

Examples

>>> from pyspark.sql.functions import map_values
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
>>> df.select(map_values("data").alias("values")).show()
+------+
|values|
+------+
|[a, b]|
+------+


pyspark.sql.functions.map_zip_with

pyspark.sql.functions.map_zip_with(col1, col2, f)Merge two given maps, key-wise into a single map using a function.

New in version 3.1.0.

Parameters

col1 [Column or str] name of the first column or expression

col2 [Column or str] name of the second column or expression

f [function] a ternary function (k: Column, v1: Column, v2: Column) ->Column... Can use methods of pyspark.sql.Column, functions definedin pyspark.sql.functions and Scala UserDefinedFunctions. PythonUserDefinedFunctions are not supported (SPARK-27052).

Returns

pyspark.sql.Column

Examples

>>> df = spark.createDataFrame([
...     (1, {"IT": 24.0, "SALES": 12.00}, {"IT": 2.0, "SALES": 1.4})],
...     ("id", "base", "ratio")
... )
>>> df.select(map_zip_with(
...     "base", "ratio", lambda k, v1, v2: round(v1 * v2, 2)).alias("updated_data")
... ).show(truncate=False)
+---------------------------+
|updated_data               |
+---------------------------+
|{SALES -> 16.8, IT -> 48.0}|
+---------------------------+

pyspark.sql.functions.max

pyspark.sql.functions.max(col)Aggregate function: returns the maximum value of the expression in a group.

New in version 1.3.
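
A brief added sketch (assuming a SparkSession named spark):

>>> df = spark.createDataFrame([(2,), (5,)], ['age'])
>>> df.agg(max('age').alias('max_age')).collect()
[Row(max_age=5)]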

pyspark.sql.functions.md5

pyspark.sql.functions.md5(col)Calculates the MD5 digest and returns the value as a 32 character hex string.

New in version 1.5.0.


Examples

>>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect()
[Row(hash='902fbdd2b1df0c4f70b4a5d23525e932')]

pyspark.sql.functions.mean

pyspark.sql.functions.mean(col)Aggregate function: returns the average of the values in a group.

New in version 1.3.

pyspark.sql.functions.min

pyspark.sql.functions.min(col)Aggregate function: returns the minimum value of the expression in a group.

New in version 1.3.

pyspark.sql.functions.minute

pyspark.sql.functions.minute(col)Extract the minutes of a given date as integer.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts'])>>> df.select(minute('ts').alias('minute')).collect()[Row(minute=8)]

pyspark.sql.functions.monotonically_increasing_id

pyspark.sql.functions.monotonically_increasing_id()A column that generates monotonically increasing 64-bit integers.

The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. The current implementation puts the partition ID in the upper 31 bits, and the record number within each partition in the lower 33 bits. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records.

New in version 1.6.0.


Notes

The function is non-deterministic because its result depends on partition IDs.

As an example, consider a DataFrame with two partitions, each with 3 records. This expression would returnthe following IDs: 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594.

>>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1'])
>>> df0.select(monotonically_increasing_id().alias('id')).collect()
[Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)]

pyspark.sql.functions.month

pyspark.sql.functions.month(col)Extract the month of a given date as integer.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])>>> df.select(month('dt').alias('month')).collect()[Row(month=4)]

pyspark.sql.functions.months

pyspark.sql.functions.months(col)Partition transform function: A transform for timestamps and dates to partition data into months.

New in version 3.1.0.

Notes

This function can be used only in combination with partitionedBy() method of the DataFrameWriterV2.

Examples

>>> df.writeTo("catalog.db.table").partitionedBy(... months("ts")... ).createOrReplace()


pyspark.sql.functions.months_between

pyspark.sql.functions.months_between(date1, date2, roundOff=True)
Returns number of months between dates date1 and date2. If date1 is later than date2, then the result is positive. If date1 and date2 are on the same day of month, or both are the last day of month, returns an integer (time of day will be ignored). The result is rounded off to 8 digits unless roundOff is set to False.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2'])
>>> df.select(months_between(df.date1, df.date2).alias('months')).collect()
[Row(months=3.94959677)]
>>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect()
[Row(months=3.9495967741935485)]

pyspark.sql.functions.nanvl

pyspark.sql.functions.nanvl(col1, col2)Returns col1 if it is not NaN, or col2 if col1 is NaN.

Both inputs should be floating point columns (DoubleType or FloatType).

New in version 1.6.0.

Examples

>>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b"))
>>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect()
[Row(r1=1.0, r2=1.0), Row(r1=2.0, r2=2.0)]

pyspark.sql.functions.next_day

pyspark.sql.functions.next_day(date, dayOfWeek)Returns the first date which is later than the value of the date column.

Day of the week parameter is case insensitive, and accepts: “Mon”, “Tue”, “Wed”, “Thu”, “Fri”, “Sat”,“Sun”.

New in version 1.5.0.


Examples

>>> df = spark.createDataFrame([('2015-07-27',)], ['d'])
>>> df.select(next_day(df.d, 'Sun').alias('date')).collect()
[Row(date=datetime.date(2015, 8, 2))]

pyspark.sql.functions.nth_value

pyspark.sql.functions.nth_value(col, offset, ignoreNulls=False)
Window function: returns the value that is the offset-th row of the window frame (counting from 1), and null if the size of the window frame is less than offset rows.

It will return the offset-th non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned.

This is equivalent to the nth_value function in SQL.

New in version 3.1.0.

Parameters

col [Column or str] name of column or expression

offset [int, optional] number of row to use as the value

ignoreNulls [bool, optional] indicates the Nth value should skip null in the determination of which row to use
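
An illustrative sketch of typical usage over an explicit window frame (the data, window spec and expected output are assumptions added for clarity, not from the upstream docstring):

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([("a", 1), ("a", 2), ("a", 3)], ("key", "value"))
>>> w = Window.partitionBy("key").orderBy("value") \
...     .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
>>> df.withColumn("second", nth_value("value", 2).over(w)).collect()
[Row(key='a', value=1, second=2), Row(key='a', value=2, second=2), Row(key='a', value=3, second=2)]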

pyspark.sql.functions.ntile

pyspark.sql.functions.ntile(n)
Window function: returns the ntile group id (from 1 to n inclusive) in an ordered window partition. For example, if n is 4, the first quarter of the rows will get value 1, the second quarter will get 2, the third quarter will get 3, and the last quarter will get 4.

This is equivalent to the NTILE function in SQL.

New in version 1.4.0.

Parameters

n [int] an integer
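
An illustrative sketch (the data, window spec and expected output are assumptions added for clarity):

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([("a", 1), ("a", 2), ("a", 3), ("a", 4)], ("key", "value"))
>>> w = Window.partitionBy("key").orderBy("value")
>>> df.withColumn("bucket", ntile(2).over(w)).collect()
[Row(key='a', value=1, bucket=1), Row(key='a', value=2, bucket=1), Row(key='a', value=3, bucket=2), Row(key='a', value=4, bucket=2)]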

pyspark.sql.functions.overlay

pyspark.sql.functions.overlay(src, replace, pos, len=-1)
Overlay the specified portion of src with replace, starting from byte position pos of src and proceeding for len bytes.

New in version 3.0.0.


Examples

>>> df = spark.createDataFrame([("SPARK_SQL", "CORE")], ("x", "y"))
>>> df.select(overlay("x", "y", 7).alias("overlayed")).show()
+----------+
| overlayed|
+----------+
|SPARK_CORE|
+----------+

pyspark.sql.functions.pandas_udf

pyspark.sql.functions.pandas_udf(f=None, returnType=None, functionType=None)Creates a pandas user defined function (a.k.a. vectorized user defined function).

Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and Pandas to work with the data, which allows vectorized operations. A Pandas UDF is defined using the pandas_udf as a decorator or to wrap the function, and no additional configuration is required. A Pandas UDF behaves as a regular PySpark function API in general.

New in version 2.3.0.

Parameters

f [function, optional] user-defined function. A python function if used as a standalone function

returnType [pyspark.sql.types.DataType or str, optional] the return type of the user-defined function. The value can be either a pyspark.sql.types.DataType object ora DDL-formatted type string.

functionType [int, optional] an enum value in pyspark.sql.functions.PandasUDFType. Default: SCALAR. This parameter exists for compatibility. UsingPython type hints is encouraged.

See also:

pyspark.sql.GroupedData.agg

pyspark.sql.DataFrame.mapInPandas

pyspark.sql.GroupedData.applyInPandas

pyspark.sql.PandasCogroupedOps.applyInPandas

pyspark.sql.UDFRegistration.register

Notes

The user-defined functions do not support conditional expressions or short circuiting in boolean expressions, and all expressions end up being executed internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions.

The user-defined functions do not take keyword arguments on the calling side.

The data type of returned pandas.Series from the user-defined functions should be matched with defined returnType (see types.to_arrow_type() and types.from_arrow_type()). When there is mismatch between them, Spark might do conversion on returned data. The conversion is not guaranteed to be correct and results should be checked for accuracy by users.


Currently, pyspark.sql.types.ArrayType of pyspark.sql.types.TimestampType and nested pyspark.sql.types.StructType are not supported as output types.

Examples

In order to use this API, customarily the below are imported:

>>> import pandas as pd
>>> from pyspark.sql.functions import pandas_udf

From Spark 3.0 with Python 3.6+, Python type hints detect the function types as below:

>>> @pandas_udf(IntegerType())
... def slen(s: pd.Series) -> pd.Series:
...     return s.str.len()

Prior to Spark 3.0, the pandas UDF used functionType to decide the execution type as below:

>>> from pyspark.sql.functions import PandasUDFType
>>> from pyspark.sql.types import IntegerType
>>> @pandas_udf(IntegerType(), PandasUDFType.SCALAR)
... def slen(s):
...     return s.str.len()

It is preferred to specify type hints for the pandas UDF instead of specifying the pandas UDF type via functionType, which will be deprecated in future releases.

Note that the type hint should use pandas.Series in all cases, with one variant: pandas.DataFrame should be used as the input or output type hint instead when the input or output column is of pyspark.sql.types.StructType. The following example shows a Pandas UDF which takes a long column, a string column and a struct column, and outputs a struct column. It requires the function to specify the type hints of pandas.Series and pandas.DataFrame as below:

>>> @pandas_udf("col1 string, col2 long")
... def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame:
...     s3['col2'] = s1 + s2.str.len()
...     return s3
...
>>> # Create a Spark DataFrame that has three columns including a struct column.
>>> df = spark.createDataFrame(
...     [[1, "a string", ("a nested string",)]],
...     "long_col long, string_col string, struct_col struct<col1:string>")
>>> df.printSchema()
root
 |-- long_col: long (nullable = true)
 |-- string_col: string (nullable = true)
 |-- struct_col: struct (nullable = true)
 |    |-- col1: string (nullable = true)
>>> df.select(func("long_col", "string_col", "struct_col")).printSchema()
root
 |-- func(long_col, string_col, struct_col): struct (nullable = true)
 |    |-- col1: string (nullable = true)
 |    |-- col2: long (nullable = true)

The following sections describe the combinations of supported type hints. For simplicity, the pandas.DataFrame variant is omitted.

• Series to Series pandas.Series, . . . -> pandas.Series


The function takes one or more pandas.Series and outputs one pandas.Series. The output of thefunction should always be of the same length as the input.

>>> @pandas_udf("string")
... def to_upper(s: pd.Series) -> pd.Series:
...     return s.str.upper()
...
>>> df = spark.createDataFrame([("John Doe",)], ("name",))
>>> df.select(to_upper("name")).show()
+--------------+
|to_upper(name)|
+--------------+
|      JOHN DOE|
+--------------+

>>> @pandas_udf("first string, last string")
... def split_expand(s: pd.Series) -> pd.DataFrame:
...     return s.str.split(expand=True)
...
>>> df = spark.createDataFrame([("John Doe",)], ("name",))
>>> df.select(split_expand("name")).show()
+------------------+
|split_expand(name)|
+------------------+
|       [John, Doe]|
+------------------+

Note: The length of the input is not that of the whole input column, but is the length of an internalbatch used for each call to the function.

• Iterator of Series to Iterator of Series Iterator[pandas.Series] -> Iterator[pandas.Series]

The function takes an iterator of pandas.Series and outputs an iterator of pandas.Series. In this case, the created pandas UDF instance requires one input column when this is called as a PySpark column. The length of the entire output from the function should be the same as the length of the entire input; therefore, it can prefetch the data from the input iterator as long as the lengths are the same.

It is also useful when the UDF execution requires initializing some state, although internally it works identically to the Series to Series case. The pseudocode below illustrates the example.

@pandas_udf("long")
def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Do some expensive initialization with a state
    state = very_expensive_initialization()
    for x in iterator:
        # Use that state for the whole iterator.
        yield calculate_with_state(x, state)

df.select(calculate("value")).show()

>>> from typing import Iterator
>>> @pandas_udf("long")
... def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
...     for s in iterator:
...         yield s + 1
...
>>> df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"]))
>>> df.select(plus_one(df.v)).show()
+-----------+
|plus_one(v)|
+-----------+
|          2|
|          3|
|          4|
+-----------+

Note: The length of each series is the length of a batch internally used.

• Iterator of Multiple Series to Iterator of Series Iterator[Tuple[pandas.Series, . . . ]] -> Itera-tor[pandas.Series]

The function takes an iterator of a tuple of multiple pandas.Series and outputs an iterator of pandas.Series. In this case, the created pandas UDF instance requires as many input columns as there are series when this is called as a PySpark column. Otherwise, it has the same characteristics and restrictions as the Iterator of Series to Iterator of Series case.

>>> from typing import Iterator, Tuple
>>> from pyspark.sql.functions import struct, col
>>> @pandas_udf("long")
... def multiply(iterator: Iterator[Tuple[pd.Series, pd.DataFrame]]) -> Iterator[pd.Series]:
...     for s1, df in iterator:
...         yield s1 * df.v
...
>>> df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"]))
>>> df.withColumn('output', multiply(col("v"), struct(col("v")))).show()
+---+------+
|  v|output|
+---+------+
|  1|     1|
|  2|     4|
|  3|     9|
+---+------+

Note: The length of each series is the length of a batch internally used.

• Series to Scalar pandas.Series, . . . -> Any

The function takes pandas.Series and returns a scalar value. The returnType should be a primitive datatype, and the returned scalar can be either a python primitive type, e.g., int or float or a numpy datatype, e.g., numpy.int64 or numpy.float64. Any should ideally be a specific scalar type accordingly.

>>> @pandas_udf("double")
... def mean_udf(v: pd.Series) -> float:
...     return v.mean()
...
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
>>> df.groupby("id").agg(mean_udf(df['v'])).show()
+---+-----------+
| id|mean_udf(v)|
+---+-----------+
|  1|        1.5|
|  2|        6.0|
+---+-----------+

This UDF can also be used as window functions as below:

>>> from pyspark.sql import Window
>>> @pandas_udf("double")
... def mean_udf(v: pd.Series) -> float:
...     return v.mean()
...
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
>>> w = Window.partitionBy('id').orderBy('v').rowsBetween(-1, 0)
>>> df.withColumn('mean_v', mean_udf("v").over(w)).show()
+---+----+------+
| id|   v|mean_v|
+---+----+------+
|  1| 1.0|   1.0|
|  1| 2.0|   1.5|
|  2| 3.0|   3.0|
|  2| 5.0|   4.0|
|  2|10.0|   7.5|
+---+----+------+

Note: For performance reasons, the input series to window functions are not copied. Therefore,mutating the input series is not allowed and will cause incorrect results. For the same reason, usersshould also not rely on the index of the input series.

pyspark.sql.functions.percent_rank

pyspark.sql.functions.percent_rank()Window function: returns the relative rank (i.e. percentile) of rows within a window partition.

New in version 1.6.
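
An illustrative sketch (the data, window spec and expected output are assumptions added for clarity):

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([("a", 1), ("a", 2), ("a", 3)], ("key", "value"))
>>> w = Window.partitionBy("key").orderBy("value")
>>> df.withColumn("pr", percent_rank().over(w)).collect()
[Row(key='a', value=1, pr=0.0), Row(key='a', value=2, pr=0.5), Row(key='a', value=3, pr=1.0)]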

pyspark.sql.functions.percentile_approx

pyspark.sql.functions.percentile_approx(col, percentage, accuracy=10000)
Returns the approximate percentile of the numeric column col which is the smallest value in the ordered col values (sorted from least to greatest) such that no more than percentage of col values is less than the value or equal to that value. The value of percentage must be between 0.0 and 1.0.

The accuracy parameter (default: 10000) is a positive numeric literal which controls approximation accuracy at the cost of memory. A higher value of accuracy yields better accuracy; 1.0/accuracy is the relative error of the approximation.

When percentage is an array, each value of the percentage array must be between 0.0 and 1.0. In this case, it returns the approximate percentile array of column col at the given percentage array.

New in version 3.1.0.


Examples

>>> key = (col("id") % 3).alias("key")
>>> value = (randn(42) + key * 10).alias("value")
>>> df = spark.range(0, 1000, 1, 1).select(key, value)
>>> df.select(
...     percentile_approx("value", [0.25, 0.5, 0.75], 1000000).alias("quantiles")
... ).printSchema()
root
 |-- quantiles: array (nullable = true)
 |    |-- element: double (containsNull = false)

>>> df.groupBy("key").agg(
...     percentile_approx("value", 0.5, lit(1000000)).alias("median")
... ).printSchema()
root
 |-- key: long (nullable = true)
 |-- median: double (nullable = true)

pyspark.sql.functions.posexplode

pyspark.sql.functions.posexplode(col)
Returns a new row for each element with position in the given array or map. Uses the default column name pos for position, and col for elements in the array and key and value for elements in the map unless specified otherwise.

New in version 2.1.0.

Examples

>>> from pyspark.sql import Row
>>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
>>> eDF.select(posexplode(eDF.intlist)).collect()
[Row(pos=0, col=1), Row(pos=1, col=2), Row(pos=2, col=3)]

>>> eDF.select(posexplode(eDF.mapfield)).show()
+---+---+-----+
|pos|key|value|
+---+---+-----+
|  0|  a|    b|
+---+---+-----+

pyspark.sql.functions.posexplode_outer

pyspark.sql.functions.posexplode_outer(col)
Returns a new row for each element with position in the given array or map. Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced. Uses the default column name pos for position, and col for elements in the array and key and value for elements in the map unless specified otherwise.

New in version 2.3.0.


Examples

>>> df = spark.createDataFrame(
...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
...     ("id", "an_array", "a_map")
... )
>>> df.select("id", "an_array", posexplode_outer("a_map")).show()
+---+----------+----+----+-----+
| id|  an_array| pos| key|value|
+---+----------+----+----+-----+
|  1|[foo, bar]|   0|   x|  1.0|
|  2|        []|null|null| null|
|  3|      null|null|null| null|
+---+----------+----+----+-----+
>>> df.select("id", "a_map", posexplode_outer("an_array")).show()
+---+----------+----+----+
| id|     a_map| pos| col|
+---+----------+----+----+
|  1|{x -> 1.0}|   0| foo|
|  1|{x -> 1.0}|   1| bar|
|  2|        {}|null|null|
|  3|      null|null|null|
+---+----------+----+----+

pyspark.sql.functions.pow

pyspark.sql.functions.pow(col1, col2)Returns the value of the first argument raised to the power of the second argument.

New in version 1.4.
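
A brief illustrative example, added for clarity (sample data and alias are assumptions; the result is always a double):

>>> spark.createDataFrame([(2, 3)], ("a", "b")).select(pow("a", "b").alias("p")).collect()
[Row(p=8.0)]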

pyspark.sql.functions.quarter

pyspark.sql.functions.quarter(col)Extract the quarter of a given date as integer.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(quarter('dt').alias('quarter')).collect()
[Row(quarter=2)]


pyspark.sql.functions.radians

pyspark.sql.functions.radians(col)Converts an angle measured in degrees to an approximately equivalent angle measured in radians.

New in version 2.1.0.

Parameters

col [Column or str] angle in degrees

Returns

Column angle in radians, as if computed by java.lang.Math.toRadians()

pyspark.sql.functions.raise_error

pyspark.sql.functions.raise_error(errMsg)Throws an exception with the provided error message.

New in version 3.1.

pyspark.sql.functions.rand

pyspark.sql.functions.rand(seed=None)Generates a random column with independent and identically distributed (i.i.d.) samples uniformly distributedin [0.0, 1.0).

New in version 1.4.0.

Notes

The function is non-deterministic in general case.

Examples

>>> df.withColumn('rand', rand(seed=42) * 3).collect()
[Row(age=2, name='Alice', rand=2.4052597283576684),
 Row(age=5, name='Bob', rand=2.3913904055683974)]

pyspark.sql.functions.randn

pyspark.sql.functions.randn(seed=None)Generates a column with independent and identically distributed (i.i.d.) samples from the standard normaldistribution.

New in version 1.4.0.


Notes

The function is non-deterministic in general case.

Examples

>>> df.withColumn('randn', randn(seed=42)).collect()
[Row(age=2, name='Alice', randn=1.1027054481455365),
 Row(age=5, name='Bob', randn=0.7400395449950132)]

pyspark.sql.functions.rank

pyspark.sql.functions.rank()Window function: returns the rank of rows within a window partition.

The difference between rank and dense_rank is that dense_rank leaves no gaps in the ranking sequence when there are ties. That is, if you were ranking a competition using dense_rank and had three people tie for second place, you would say that all three were in second place and that the next person came in third. Rank, on the other hand, would assign sequential numbers, so the person who came in after the three-way tie for second place would register as coming in fifth.

This is equivalent to the RANK function in SQL.

New in version 1.6.
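
An illustrative sketch showing the gap left after a tie (the data, window spec and expected output are assumptions added for clarity):

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([("a", 1), ("a", 2), ("a", 2), ("a", 3)], ("key", "value"))
>>> w = Window.partitionBy("key").orderBy("value")
>>> df.withColumn("rank", rank().over(w)).collect()
[Row(key='a', value=1, rank=1), Row(key='a', value=2, rank=2), Row(key='a', value=2, rank=2), Row(key='a', value=3, rank=4)]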

pyspark.sql.functions.regexp_extract

pyspark.sql.functions.regexp_extract(str, pattern, idx)Extract a specific group matched by a Java regex, from the specified string column. If the regex did not match,or the specified group did not match, an empty string is returned.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('100-200',)], ['str'])
>>> df.select(regexp_extract('str', r'(\d+)-(\d+)', 1).alias('d')).collect()
[Row(d='100')]
>>> df = spark.createDataFrame([('foo',)], ['str'])
>>> df.select(regexp_extract('str', r'(\d+)', 1).alias('d')).collect()
[Row(d='')]
>>> df = spark.createDataFrame([('aaaac',)], ['str'])
>>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect()
[Row(d='')]


pyspark.sql.functions.regexp_replace

pyspark.sql.functions.regexp_replace(str, pattern, replacement)
Replace all substrings of the specified string value that match the regex pattern with the given replacement.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('100-200',)], ['str'])
>>> df.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect()
[Row(d='-----')]

pyspark.sql.functions.repeat

pyspark.sql.functions.repeat(col, n)Repeats a string column n times, and returns it as a new string column.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('ab',)], ['s',])
>>> df.select(repeat(df.s, 3).alias('s')).collect()
[Row(s='ababab')]

pyspark.sql.functions.reverse

pyspark.sql.functions.reverse(col)Collection function: returns a reversed string or an array with reverse order of elements.

New in version 1.5.0.

Parameters

col [Column or str] name of column or expression

Examples

>>> df = spark.createDataFrame([('Spark SQL',)], ['data'])
>>> df.select(reverse(df.data).alias('s')).collect()
[Row(s='LQS krapS')]
>>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data'])
>>> df.select(reverse(df.data).alias('r')).collect()
[Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])]


pyspark.sql.functions.rint

pyspark.sql.functions.rint(col)Returns the double value that is closest in value to the argument and is equal to a mathematical integer.

New in version 1.4.
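
A brief illustrative example, added for clarity (sample data and alias are assumptions; rint follows java.lang.Math.rint, which rounds ties to the nearest even integer):

>>> spark.createDataFrame([(2.5,), (3.5,)], ['a']).select(rint('a').alias('r')).collect()
[Row(r=2.0), Row(r=4.0)]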

pyspark.sql.functions.round

pyspark.sql.functions.round(col, scale=0)
Round the given value to scale decimal places using HALF_UP rounding mode if scale >= 0 or at integral part when scale < 0.

New in version 1.5.0.

Examples

>>> spark.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect()
[Row(r=3.0)]

pyspark.sql.functions.row_number

pyspark.sql.functions.row_number()Window function: returns a sequential number starting at 1 within a window partition.

New in version 1.6.
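
An illustrative sketch (the data, window spec and expected output are assumptions added for clarity):

>>> from pyspark.sql import Window
>>> df = spark.createDataFrame([("a", 10), ("a", 20)], ("key", "value"))
>>> w = Window.partitionBy("key").orderBy("value")
>>> df.withColumn("rn", row_number().over(w)).collect()
[Row(key='a', value=10, rn=1), Row(key='a', value=20, rn=2)]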

pyspark.sql.functions.rpad

pyspark.sql.functions.rpad(col, len, pad)Right-pad the string column to width len with pad.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('abcd',)], ['s',])
>>> df.select(rpad(df.s, 6, '#').alias('s')).collect()
[Row(s='abcd##')]

pyspark.sql.functions.rtrim

pyspark.sql.functions.rtrim(col)Trim the spaces from right end for the specified string value.

New in version 1.5.
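
A brief illustrative example, added for clarity (sample data and alias are assumptions, not from the upstream docstring):

>>> spark.createDataFrame([("  Spark  ",)], ["s"]).select(rtrim("s").alias("r")).collect()
[Row(r='  Spark')]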


pyspark.sql.functions.schema_of_csv

pyspark.sql.functions.schema_of_csv(csv, options=None)Parses a CSV string and infers its schema in DDL format.

New in version 3.0.0.

Parameters

csv [Column or str] a CSV string or a foldable string column containing a CSV string.

options [dict, optional] options to control parsing. accepts the same options as the CSV data-source

Examples

>>> df = spark.range(1)
>>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect()
[Row(csv='STRUCT<`_c0`: INT, `_c1`: STRING>')]
>>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect()
[Row(csv='STRUCT<`_c0`: INT, `_c1`: STRING>')]

pyspark.sql.functions.schema_of_json

pyspark.sql.functions.schema_of_json(json, options=None)Parses a JSON string and infers its schema in DDL format.

New in version 2.4.0.

Parameters

json [Column or str] a JSON string or a foldable string column containing a JSON string.

options [dict, optional] options to control parsing. accepts the same options as the JSON data-source

Changed in version 3.0: It accepts options parameter to control schema inferring.

Examples

>>> df = spark.range(1)
>>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect()
[Row(json='STRUCT<`a`: BIGINT>')]
>>> schema = schema_of_json('{a: 1}', {'allowUnquotedFieldNames':'true'})
>>> df.select(schema.alias("json")).collect()
[Row(json='STRUCT<`a`: BIGINT>')]


pyspark.sql.functions.second

pyspark.sql.functions.second(col)Extract the seconds of a given date as integer.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts'])
>>> df.select(second('ts').alias('second')).collect()
[Row(second=15)]

pyspark.sql.functions.sequence

pyspark.sql.functions.sequence(start, stop, step=None)
Generate a sequence of integers from start to stop, incrementing by step. If step is not set, incrementing by 1 if start is less than or equal to stop, otherwise -1.

New in version 2.4.0.

Examples

>>> df1 = spark.createDataFrame([(-2, 2)], ('C1', 'C2'))
>>> df1.select(sequence('C1', 'C2').alias('r')).collect()
[Row(r=[-2, -1, 0, 1, 2])]
>>> df2 = spark.createDataFrame([(4, -4, -2)], ('C1', 'C2', 'C3'))
>>> df2.select(sequence('C1', 'C2', 'C3').alias('r')).collect()
[Row(r=[4, 2, 0, -2, -4])]

pyspark.sql.functions.sha1

pyspark.sql.functions.sha1(col)Returns the hex string result of SHA-1.

New in version 1.5.0.

Examples

>>> spark.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
[Row(hash='3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]


pyspark.sql.functions.sha2

pyspark.sql.functions.sha2(col, numBits)
Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, and SHA-512). The numBits indicates the desired bit length of the result, which must have a value of 224, 256, 384, 512, or 0 (which is equivalent to 256).

New in version 1.5.0.

Examples

>>> digests = df.select(sha2(df.name, 256).alias('s')).collect()
>>> digests[0]
Row(s='3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043')
>>> digests[1]
Row(s='cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961')

pyspark.sql.functions.shiftLeft

pyspark.sql.functions.shiftLeft(col, numBits)Shift the given value numBits left.

New in version 1.5.0.

Examples

>>> spark.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect()
[Row(r=42)]

pyspark.sql.functions.shiftRight

pyspark.sql.functions.shiftRight(col, numBits)(Signed) shift the given value numBits right.

New in version 1.5.0.

Examples

>>> spark.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect()
[Row(r=21)]


pyspark.sql.functions.shiftRightUnsigned

pyspark.sql.functions.shiftRightUnsigned(col, numBits)Unsigned shift the given value numBits right.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([(-42,)], ['a'])
>>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect()
[Row(r=9223372036854775787)]

pyspark.sql.functions.shuffle

pyspark.sql.functions.shuffle(col)Collection function: Generates a random permutation of the given array.

New in version 2.4.0.

Parameters

col [Column or str] name of column or expression

Notes

The function is non-deterministic.

Examples

>>> df = spark.createDataFrame([([1, 20, 3, 5],), ([1, 20, None, 3],)], ['data'])
>>> df.select(shuffle(df.data).alias('s')).collect()
[Row(s=[3, 1, 5, 20]), Row(s=[20, None, 3, 1])]

pyspark.sql.functions.signum

pyspark.sql.functions.signum(col)Computes the signum of the given value.

New in version 1.4.

pyspark.sql.functions.sin

pyspark.sql.functions.sin(col)New in version 1.4.0.

Parameters

col [Column or str]

Returns

Column sine of the angle, as if computed by java.lang.Math.sin()


pyspark.sql.functions.sinh

pyspark.sql.functions.sinh(col)New in version 1.4.0.

Parameters

col [Column or str] hyperbolic angle

Returns

Column hyperbolic sine of the given value, as if computed by java.lang.Math.sinh()

pyspark.sql.functions.size

pyspark.sql.functions.size(col)Collection function: returns the length of the array or map stored in the column.

New in version 1.5.0.

Parameters

col [Column or str] name of column or expression

Examples

>>> df = spark.createDataFrame([([1, 2, 3],),([1],),([],)], ['data'])
>>> df.select(size(df.data)).collect()
[Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)]

pyspark.sql.functions.skewness

pyspark.sql.functions.skewness(col)Aggregate function: returns the skewness of the values in a group.

New in version 1.6.

pyspark.sql.functions.slice

pyspark.sql.functions.slice(x, start, length)
Collection function: returns an array containing all the elements in x from index start (array indices start at 1, or from the end if start is negative) with the specified length.

New in version 2.4.0.

Parameters

x [Column or str] the array to be sliced

start [Column or int] the starting index

length [Column or int] the length of the slice


Examples

>>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
>>> df.select(slice(df.x, 2, 2).alias("sliced")).collect()
[Row(sliced=[2, 3]), Row(sliced=[5])]

pyspark.sql.functions.sort_array

pyspark.sql.functions.sort_array(col, asc=True)
Collection function: sorts the input array in ascending or descending order according to the natural ordering of the array elements. Null elements will be placed at the beginning of the returned array in ascending order or at the end of the returned array in descending order.

New in version 1.5.0.

Parameters

col [Column or str] name of column or expression

asc [bool, optional]

Examples

>>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
>>> df.select(sort_array(df.data).alias('r')).collect()
[Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])]
>>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
[Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])]

pyspark.sql.functions.soundex

pyspark.sql.functions.soundex(col)Returns the SoundEx encoding for a string

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([("Peters",),("Uhrbach",)], ['name'])
>>> df.select(soundex(df.name).alias("soundex")).collect()
[Row(soundex='P362'), Row(soundex='U612')]

pyspark.sql.functions.spark_partition_id

pyspark.sql.functions.spark_partition_id()A column for partition ID.

New in version 1.6.0.


Notes

This is non deterministic because it depends on data partitioning and task scheduling.

Examples

>>> df.repartition(1).select(spark_partition_id().alias("pid")).collect()
[Row(pid=0), Row(pid=0)]

pyspark.sql.functions.split

pyspark.sql.functions.split(str, pattern, limit=-1)
Splits str around matches of the given pattern.

New in version 1.5.0.

Parameters

str [Column or str] a string expression to split

pattern [str] a string representing a regular expression. The regex string should be a Java regular expression.

limit [int, optional] an integer which controls the number of times pattern is applied.

• limit > 0: The resulting array’s length will not be more than limit, and the resulting array’s last entry will contain all input beyond the last matched pattern.

• limit <= 0: pattern will be applied as many times as possible, and the resulting array can be of any size.

Changed in version 3.0: split now takes an optional limit field. If not provided, the default limit value is -1.

Examples

>>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',])
>>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect()
[Row(s=['one', 'twoBthreeC'])]
>>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect()
[Row(s=['one', 'two', 'three', ''])]

pyspark.sql.functions.sqrt

pyspark.sql.functions.sqrt(col)Computes the square root of the specified float value.

New in version 1.3.


pyspark.sql.functions.stddev

pyspark.sql.functions.stddev(col)Aggregate function: alias for stddev_samp.

New in version 1.6.

pyspark.sql.functions.stddev_pop

pyspark.sql.functions.stddev_pop(col)Aggregate function: returns population standard deviation of the expression in a group.

New in version 1.6.

pyspark.sql.functions.stddev_samp

pyspark.sql.functions.stddev_samp(col)Aggregate function: returns the unbiased sample standard deviation of the expression in a group.

New in version 1.6.
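
A minimal sketch covering the stddev family (the sample data and expected value are assumptions added for clarity; stddev is an alias for stddev_samp and returns the same result):

>>> df = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ["v"])
>>> df.select(stddev_samp("v").alias("sd")).collect()
[Row(sd=1.0)]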

pyspark.sql.functions.struct

pyspark.sql.functions.struct(*cols)Creates a new struct column.

New in version 1.4.0.

Parameters

cols [list, set, str or Column] column names or Columns to contain in the output struct.

Examples

>>> df.select(struct('age', 'name').alias("struct")).collect()
[Row(struct=Row(age=2, name='Alice')), Row(struct=Row(age=5, name='Bob'))]
>>> df.select(struct([df.age, df.name]).alias("struct")).collect()
[Row(struct=Row(age=2, name='Alice')), Row(struct=Row(age=5, name='Bob'))]

pyspark.sql.functions.substring

pyspark.sql.functions.substring(str, pos, len)
Substring starts at pos and is of length len when str is String type, or returns the slice of byte array that starts at pos in byte and is of length len when str is Binary type.

New in version 1.5.0.


Notes

The position is not zero-based; it is a 1-based index.

Examples

>>> df = spark.createDataFrame([('abcd',)], ['s',])
>>> df.select(substring(df.s, 1, 2).alias('s')).collect()
[Row(s='ab')]

pyspark.sql.functions.substring_index

pyspark.sql.functions.substring_index(str, delim, count)
Returns the substring from string str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. substring_index performs a case-sensitive match when searching for delim.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('a.b.c.d',)], ['s'])
>>> df.select(substring_index(df.s, '.', 2).alias('s')).collect()
[Row(s='a.b')]
>>> df.select(substring_index(df.s, '.', -3).alias('s')).collect()
[Row(s='b.c.d')]

pyspark.sql.functions.sum

pyspark.sql.functions.sum(col)Aggregate function: returns the sum of all values in the expression.

New in version 1.3.

pyspark.sql.functions.sumDistinct

pyspark.sql.functions.sumDistinct(col)Aggregate function: returns the sum of distinct values in the expression.

New in version 1.3.
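
A brief illustrative example covering both sum and sumDistinct (sample data, aliases and expected output are assumptions added for clarity):

>>> df = spark.createDataFrame([(1,), (2,), (2,)], ["v"])
>>> df.select(sum("v").alias("total"), sumDistinct("v").alias("distinct_total")).collect()
[Row(total=5, distinct_total=3)]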


pyspark.sql.functions.tan

pyspark.sql.functions.tan(col)New in version 1.4.0.

Parameters

col [Column or str] angle in radians

Returns

Column tangent of the given value, as if computed by java.lang.Math.tan()

pyspark.sql.functions.tanh

pyspark.sql.functions.tanh(col)New in version 1.4.0.

Parameters

col [Column or str] hyperbolic angle

Returns

Column hyperbolic tangent of the given value as if computed by java.lang.Math.tanh()

pyspark.sql.functions.timestamp_seconds

pyspark.sql.functions.timestamp_seconds(col)New in version 3.1.0.

Examples

>>> from pyspark.sql.functions import timestamp_seconds
>>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
>>> time_df = spark.createDataFrame([(1230219000,)], ['unix_time'])
>>> time_df.select(timestamp_seconds(time_df.unix_time).alias('ts')).show()
+-------------------+
|                 ts|
+-------------------+
|2008-12-25 07:30:00|
+-------------------+
>>> spark.conf.unset("spark.sql.session.timeZone")

pyspark.sql.functions.toDegrees

pyspark.sql.functions.toDegrees(col)Deprecated since version 2.1.0: Use degrees() instead.

New in version 1.4.


pyspark.sql.functions.toRadians

pyspark.sql.functions.toRadians(col)Deprecated since version 2.1.0: Use radians() instead.

New in version 1.4.

pyspark.sql.functions.to_csv

pyspark.sql.functions.to_csv(col, options=None)Converts a column containing a StructType into a CSV string. Throws an exception, in the case of anunsupported type.

New in version 3.0.0.

Parameters

col [Column or str] name of column containing a struct.

options: dict, optional options to control converting. accepts the same options as the CSVdatasource.

Examples

>>> from pyspark.sql import Row
>>> data = [(1, Row(age=2, name='Alice'))]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(to_csv(df.value).alias("csv")).collect()
[Row(csv='2,Alice')]

pyspark.sql.functions.to_date

pyspark.sql.functions.to_date(col, format=None)
Converts a Column into pyspark.sql.types.DateType using the optionally specified format. Specify formats according to datetime pattern. By default, it follows casting rules to pyspark.sql.types.DateType if the format is omitted. Equivalent to col.cast("date").

New in version 2.2.0.

Examples

>>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
>>> df.select(to_date(df.t).alias('date')).collect()
[Row(date=datetime.date(1997, 2, 28))]

>>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
>>> df.select(to_date(df.t, 'yyyy-MM-dd HH:mm:ss').alias('date')).collect()
[Row(date=datetime.date(1997, 2, 28))]


pyspark.sql.functions.to_json

pyspark.sql.functions.to_json(col, options=None)Converts a column containing a StructType, ArrayType or a MapType into a JSON string. Throws anexception, in the case of an unsupported type.

New in version 2.1.0.

Parameters

col [Column or str] name of column containing a struct, an array or a map.

options [dict, optional] options to control converting. accepts the same options as the JSONdatasource. Additionally the function supports the pretty option which enables pretty JSONgeneration.

Examples

>>> from pyspark.sql import Row
>>> from pyspark.sql.types import *
>>> data = [(1, Row(age=2, name='Alice'))]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(to_json(df.value).alias("json")).collect()
[Row(json='{"age":2,"name":"Alice"}')]
>>> data = [(1, [Row(age=2, name='Alice'), Row(age=3, name='Bob')])]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(to_json(df.value).alias("json")).collect()
[Row(json='[{"age":2,"name":"Alice"},{"age":3,"name":"Bob"}]')]
>>> data = [(1, {"name": "Alice"})]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(to_json(df.value).alias("json")).collect()
[Row(json='{"name":"Alice"}')]
>>> data = [(1, [{"name": "Alice"}, {"name": "Bob"}])]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(to_json(df.value).alias("json")).collect()
[Row(json='[{"name":"Alice"},{"name":"Bob"}]')]
>>> data = [(1, ["Alice", "Bob"])]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(to_json(df.value).alias("json")).collect()
[Row(json='["Alice","Bob"]')]

pyspark.sql.functions.to_timestamp

pyspark.sql.functions.to_timestamp(col, format=None)
Converts a Column into pyspark.sql.types.TimestampType using the optionally specified format. Specify formats according to datetime pattern. By default, it follows casting rules to pyspark.sql.types.TimestampType if the format is omitted. Equivalent to col.cast("timestamp").

New in version 2.2.0.


Examples

>>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
>>> df.select(to_timestamp(df.t).alias('dt')).collect()
[Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]

>>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
>>> df.select(to_timestamp(df.t, 'yyyy-MM-dd HH:mm:ss').alias('dt')).collect()
[Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]

pyspark.sql.functions.to_utc_timestamp

pyspark.sql.functions.to_utc_timestamp(timestamp, tz)
This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function takes a timestamp which is timezone-agnostic, interprets it as a timestamp in the given timezone, and renders that timestamp as a timestamp in UTC.

However, timestamp in Spark represents the number of microseconds from the Unix epoch, which is not timezone-agnostic. So in Spark this function just shifts the timestamp value from the given timezone to the UTC timezone.

This function may return a confusing result if the input is a string with a timezone, e.g. ‘2018-03-13T06:18:23+00:00’. The reason is that Spark first casts the string to a timestamp according to the timezone in the string, and finally displays the result by converting the timestamp to a string according to the session local timezone.

New in version 1.5.0.

Parameters

timestamp [Column or str] the column that contains timestamps

tz [Column or str] A string detailing the time zone ID that the input should be adjusted to. It should be in the format of either region-based zone IDs or zone offsets. Region IDs must have the form ‘area/city’, such as ‘America/Los_Angeles’. Zone offsets must be in the format ‘(+|-)HH:mm’, for example ‘-08:00’ or ‘+01:00’. Also ‘UTC’ and ‘Z’ are supported as aliases of ‘+00:00’. Other short names are not recommended because they can be ambiguous.

Changed in version 2.4.0: tz can take a Column containing timezone ID strings.

Examples

>>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz'])
>>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect()
[Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))]
>>> df.select(to_utc_timestamp(df.ts, df.tz).alias('utc_time')).collect()
[Row(utc_time=datetime.datetime(1997, 2, 28, 1, 30))]


pyspark.sql.functions.transform

pyspark.sql.functions.transform(col, f)Returns an array of elements after applying a transformation to each element in the input array.

New in version 3.1.0.

Parameters

col [Column or str] name of column or expression

f [function] a function that is applied to each element of the input array. Can take one of the following forms:

• Unary (x: Column) -> Column: ...

• Binary (x: Column, i: Column) -> Column..., where the second argument is a 0-based index of the element.

and can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).

Returns

pyspark.sql.Column

Examples

>>> df = spark.createDataFrame([(1, [1, 2, 3, 4])], ("key", "values"))
>>> df.select(transform("values", lambda x: x * 2).alias("doubled")).show()
+------------+
|     doubled|
+------------+
|[2, 4, 6, 8]|
+------------+

>>> def alternate(x, i):
...     return when(i % 2 == 0, x).otherwise(-x)
>>> df.select(transform("values", alternate).alias("alternated")).show()
+--------------+
|    alternated|
+--------------+
|[1, -2, 3, -4]|
+--------------+

pyspark.sql.functions.transform_keys

pyspark.sql.functions.transform_keys(col, f)Applies a function to every key-value pair in a map and returns a map with the results of those applications asthe new keys for the pairs.

New in version 3.1.0.

Parameters

col [Column or str] name of column or expression


f [function] a binary function (k: Column, v: Column) -> Column... Can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).

Returns

pyspark.sql.Column

Examples

>>> df = spark.createDataFrame([(1, {"foo": -2.0, "bar": 2.0})], ("id", "data"))
>>> df.select(transform_keys(
...     "data", lambda k, _: upper(k)).alias("data_upper")
... ).show(truncate=False)
+-------------------------+
|data_upper               |
+-------------------------+
|{BAR -> 2.0, FOO -> -2.0}|
+-------------------------+

pyspark.sql.functions.transform_values

pyspark.sql.functions.transform_values(col, f)Applies a function to every key-value pair in a map and returns a map with the results of those applications asthe new values for the pairs.

New in version 3.1.0.

Parameters

col [Column or str] name of column or expression

f [function] a binary function (k: Column, v: Column) -> Column... Can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).

Returns

pyspark.sql.Column

Examples

>>> df = spark.createDataFrame([(1, {"IT": 10.0, "SALES": 2.0, "OPS": 24.0})], ("id", "data"))
>>> df.select(transform_values(
...     "data", lambda k, v: when(k.isin("IT", "OPS"), v + 10.0).otherwise(v)
... ).alias("new_data")).show(truncate=False)
+---------------------------------------+
|new_data                               |
+---------------------------------------+
|{OPS -> 34.0, IT -> 20.0, SALES -> 2.0}|
+---------------------------------------+


pyspark.sql.functions.translate

pyspark.sql.functions.translate(srcCol, matching, replace)
Translates any character in srcCol that appears in matching to the corresponding character in replace. The characters in replace correspond positionally to the characters in matching. The translation occurs whenever a character in the string matches a character in matching.

New in version 1.5.0.

Examples

>>> spark.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123") \
...     .alias('r')).collect()
[Row(r='1a2s3ae')]

pyspark.sql.functions.trim

pyspark.sql.functions.trim(col)Trim the spaces from both ends for the specified string column.

New in version 1.5.
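
A brief illustrative example, added for clarity (sample data and alias are assumptions, not from the upstream docstring):

>>> spark.createDataFrame([("  Spark  ",)], ["s"]).select(trim("s").alias("t")).collect()
[Row(t='Spark')]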

pyspark.sql.functions.trunc

pyspark.sql.functions.trunc(date, format)Returns date truncated to the unit specified by the format.

New in version 1.5.0.

Parameters

date [Column or str]

format [str] ‘year’, ‘yyyy’, ‘yy’ or ‘month’, ‘mon’, ‘mm’

Examples

>>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
>>> df.select(trunc(df.d, 'year').alias('year')).collect()
[Row(year=datetime.date(1997, 1, 1))]
>>> df.select(trunc(df.d, 'mon').alias('month')).collect()
[Row(month=datetime.date(1997, 2, 1))]


pyspark.sql.functions.udf

pyspark.sql.functions.udf(f=None, returnType=StringType)Creates a user defined function (UDF).

New in version 1.3.0.

Parameters

f [function] python function if used as a standalone function

returnType [pyspark.sql.types.DataType or str] the return type of the user-definedfunction. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.

Notes

The user-defined functions are considered deterministic by default. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. If your function is not deterministic, call asNondeterministic on the user defined function. E.g.:

>>> from pyspark.sql.types import IntegerType
>>> import random
>>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()

The user-defined functions do not support conditional expressions or short circuiting in boolean expressions, and all expressions end up being executed internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions.

The user-defined functions do not take keyword arguments on the calling side.

Examples

>>> from pyspark.sql.types import IntegerType
>>> slen = udf(lambda s: len(s), IntegerType())
>>> @udf
... def to_upper(s):
...     if s is not None:
...         return s.upper()
...
>>> @udf(returnType=IntegerType())
... def add_one(x):
...     if x is not None:
...         return x + 1
...
>>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
>>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")).show()
+----------+--------------+------------+
|slen(name)|to_upper(name)|add_one(age)|
+----------+--------------+------------+
|         8|      JOHN DOE|          22|
+----------+--------------+------------+


pyspark.sql.functions.unbase64

pyspark.sql.functions.unbase64(col)Decodes a BASE64 encoded string column and returns it as a binary column.

New in version 1.5.
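
A brief illustrative example, added for clarity (the encoded sample string and alias are assumptions, not from the upstream docstring):

>>> df = spark.createDataFrame([("U3Bhcms=",)], ["e"])
>>> df.select(unbase64("e").alias("b")).collect()
[Row(b=bytearray(b'Spark'))]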

pyspark.sql.functions.unhex

pyspark.sql.functions.unhex(col)Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to the byte representationof number.

New in version 1.5.0.

Examples

>>> spark.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
[Row(unhex(a)=bytearray(b'ABC'))]

pyspark.sql.functions.unix_timestamp

pyspark.sql.functions.unix_timestamp(timestamp=None, format='yyyy-MM-dd HH:mm:ss')
Convert a time string with the given pattern (‘yyyy-MM-dd HH:mm:ss’, by default) to a Unix time stamp (in seconds), using the default timezone and the default locale; returns null on failure.

if timestamp is None, then it returns current timestamp.

New in version 1.5.0.

Examples

>>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
>>> time_df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> time_df.select(unix_timestamp('dt', 'yyyy-MM-dd').alias('unix_time')).collect()
[Row(unix_time=1428476400)]
>>> spark.conf.unset("spark.sql.session.timeZone")

pyspark.sql.functions.upper

pyspark.sql.functions.upper(col)Converts a string expression to upper case.

New in version 1.5.
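
A brief illustrative example, added for clarity (sample data and alias are assumptions, not from the upstream docstring):

>>> spark.createDataFrame([('Spark SQL',)], ['s']).select(upper('s').alias('u')).collect()
[Row(u='SPARK SQL')]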


pyspark.sql.functions.var_pop

pyspark.sql.functions.var_pop(col)Aggregate function: returns the population variance of the values in a group.

New in version 1.6.

pyspark.sql.functions.var_samp

pyspark.sql.functions.var_samp(col)Aggregate function: returns the unbiased sample variance of the values in a group.

New in version 1.6.

pyspark.sql.functions.variance

pyspark.sql.functions.variance(col)Aggregate function: alias for var_samp

New in version 1.6.
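
A minimal sketch covering the variance family (the sample data and expected value are assumptions added for clarity; variance is an alias for var_samp and returns the same result):

>>> df = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ["v"])
>>> df.select(variance("v").alias("var"), var_samp("v").alias("var_samp")).collect()
[Row(var=1.0, var_samp=1.0)]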

pyspark.sql.functions.weekofyear

pyspark.sql.functions.weekofyear(col)Extract the week number of a given date as integer.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(weekofyear(df.dt).alias('week')).collect()
[Row(week=15)]

pyspark.sql.functions.when

pyspark.sql.functions.when(condition, value)Evaluates a list of conditions and returns one of multiple possible result expressions. If Column.otherwise() is not invoked, None is returned for unmatched conditions.

New in version 1.4.0.

Parameters

condition [Column] a boolean Column expression.

value : a literal value, or a Column expression.

>>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
[Row(age=3), Row(age=4)]

>>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect()
[Row(age=3), Row(age=None)]


pyspark.sql.functions.window

pyspark.sql.functions.window(timeColumn, windowDuration, slideDuration=None, startTime=None)

Bucketize rows into one or more time windows given a timestamp specifying column. Window starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in the order of months are not supported.

The time column must be of pyspark.sql.types.TimestampType.

Durations are provided as strings, e.g. ‘1 second’, ‘1 day 12 hours’, ‘2 minutes’. Valid interval strings are ‘week’, ‘day’, ‘hour’, ‘minute’, ‘second’, ‘millisecond’, ‘microsecond’. If the slideDuration is not provided, the windows will be tumbling windows.

The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start window intervals. For example, in order to have hourly tumbling windows that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15..., provide startTime as 15 minutes.

The output column will be a struct called ‘window’ by default with the nested columns ‘start’ and ‘end’, where ‘start’ and ‘end’ will be of pyspark.sql.types.TimestampType.

New in version 2.0.0.

Examples

>>> df = spark.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val")
>>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum"))
>>> w.select(w.window.start.cast("string").alias("start"),
...          w.window.end.cast("string").alias("end"), "sum").collect()
[Row(start='2016-03-11 09:00:05', end='2016-03-11 09:00:10', sum=1)]

pyspark.sql.functions.xxhash64

pyspark.sql.functions.xxhash64(*cols)Calculates the hash code of given columns using the 64-bit variant of the xxHash algorithm, and returns theresult as a long column.

New in version 3.0.0.

Examples

>>> spark.createDataFrame([('ABC',)], ['a']).select(xxhash64('a').alias('hash')).collect()
[Row(hash=4105715581806190027)]


pyspark.sql.functions.year

pyspark.sql.functions.year(col)Extract the year of a given date as integer.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
>>> df.select(year('dt').alias('year')).collect()
[Row(year=2015)]

pyspark.sql.functions.years

pyspark.sql.functions.years(col)Partition transform function: A transform for timestamps and dates to partition data into years.

New in version 3.1.0.

Notes

This function can be used only in combination with partitionedBy() method of the DataFrameWriterV2.

Examples

>>> df.writeTo("catalog.db.table").partitionedBy(
...     years("ts")
... ).createOrReplace()

pyspark.sql.functions.zip_with

pyspark.sql.functions.zip_with(col1, col2, f)
Merge two given arrays, element-wise, into a single array using a function. If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying the function.

New in version 3.1.0.

Parameters

col1 [Column or str] name of the first column or expression

col2 [Column or str] name of the second column or expression

f [function] a binary function (x1: Column, x2: Column) -> Column... Can use methods of pyspark.sql.Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).

Returns

pyspark.sql.Column


Examples

>>> df = spark.createDataFrame([(1, [1, 3, 5, 8], [0, 2, 4, 6])], ("id", "xs", "ys"))
>>> df.select(zip_with("xs", "ys", lambda x, y: x ** y).alias("powers")).show(truncate=False)
+---------------------------+
|powers                     |
+---------------------------+
|[1.0, 9.0, 625.0, 262144.0]|
+---------------------------+

>>> df = spark.createDataFrame([(1, ["foo", "bar"], [1, 2, 3])], ("id", "xs", "ys"))
>>> df.select(zip_with("xs", "ys", lambda x, y: concat_ws("_", x, y)).alias("xs_ys")).show()
+-----------------+
|            xs_ys|
+-----------------+
|[foo_1, bar_2, 3]|
+-----------------+

from_avro(data, jsonFormatSchema[, options]) Converts a binary column of Avro format into its corre-sponding catalyst value.

to_avro(data[, jsonFormatSchema]) Converts a column into binary of avro format.

pyspark.sql.avro.functions.from_avro

pyspark.sql.avro.functions.from_avro(data, jsonFormatSchema, options=None)
Converts a binary column of Avro format into its corresponding catalyst value. The specified schema must match the read data, otherwise the behavior is undefined: it may fail or return an arbitrary result. To deserialize the data with a compatible and evolved schema, the expected Avro schema can be set via the option avroSchema.

New in version 3.0.0.

Parameters

data [Column or str] the binary column.

jsonFormatSchema [str] the avro schema in JSON string format.

options [dict, optional] options to control how the Avro record is parsed.

Notes

Avro is built-in but external data source module since Spark 2.4. Please deploy the application as per thedeployment section of “Apache Avro Data Source Guide”.


Examples

>>> from pyspark.sql import Row
>>> from pyspark.sql.avro.functions import from_avro, to_avro
>>> data = [(1, Row(age=2, name='Alice'))]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> avroDf = df.select(to_avro(df.value).alias("avro"))
>>> avroDf.collect()
[Row(avro=bytearray(b'\x00\x00\x04\x00\nAlice'))]

>>> jsonFormatSchema = '''{"type":"record","name":"topLevelRecord","fields":
...     [{"name":"avro","type":[{"type":"record","name":"value","namespace":"topLevelRecord",
...     "fields":[{"name":"age","type":["long","null"]},
...     {"name":"name","type":["string","null"]}]},"null"]}]}'''
>>> avroDf.select(from_avro(avroDf.avro, jsonFormatSchema).alias("value")).collect()
[Row(value=Row(avro=Row(age=2, name='Alice')))]

pyspark.sql.avro.functions.to_avro

pyspark.sql.avro.functions.to_avro(data, jsonFormatSchema='')
Converts a column into binary of Avro format.

New in version 3.0.0.

Parameters

data [Column or str] the data column.

jsonFormatSchema [str, optional] user-specified output avro schema in JSON string format.

Notes

Avro is a built-in but external data source module since Spark 2.4. Please deploy the application as per the deployment section of “Apache Avro Data Source Guide”.

Examples

>>> from pyspark.sql import Row
>>> from pyspark.sql.avro.functions import to_avro
>>> data = ['SPADES']
>>> df = spark.createDataFrame(data, "string")
>>> df.select(to_avro(df.value).alias("suite")).collect()
[Row(suite=bytearray(b'\x00\x0cSPADES'))]

>>> jsonFormatSchema = '''["null", {"type": "enum", "name": "value",... "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]'''>>> df.select(to_avro(df.value, jsonFormatSchema).alias("suite")).collect()[Row(suite=bytearray(b'\x02\x00'))]


3.1.9 Window

Window.currentRow
Window.orderBy(*cols) Creates a WindowSpec with the ordering defined.
Window.partitionBy(*cols) Creates a WindowSpec with the partitioning defined.
Window.rangeBetween(start, end) Creates a WindowSpec with the frame boundaries defined, from start (inclusive) to end (inclusive).
Window.rowsBetween(start, end) Creates a WindowSpec with the frame boundaries defined, from start (inclusive) to end (inclusive).
Window.unboundedFollowing
Window.unboundedPreceding
WindowSpec.orderBy(*cols) Defines the ordering columns in a WindowSpec.
WindowSpec.partitionBy(*cols) Defines the partitioning columns in a WindowSpec.
WindowSpec.rangeBetween(start, end) Defines the frame boundaries, from start (inclusive) to end (inclusive).
WindowSpec.rowsBetween(start, end) Defines the frame boundaries, from start (inclusive) to end (inclusive).

pyspark.sql.Window.currentRow

Window.currentRow = 0

pyspark.sql.Window.orderBy

static Window.orderBy(*cols)
Creates a WindowSpec with the ordering defined.

New in version 1.4.

pyspark.sql.Window.partitionBy

static Window.partitionBy(*cols)
Creates a WindowSpec with the partitioning defined.

New in version 1.4.

pyspark.sql.Window.rangeBetween

static Window.rangeBetween(start, end)
Creates a WindowSpec with the frame boundaries defined, from start (inclusive) to end (inclusive).

Both start and end are relative to the current row. For example, “0” means “current row”, while “-1” means one off before the current row, and “5” means five off after the current row.

We recommend users use Window.unboundedPreceding, Window.unboundedFollowing, and Window.currentRow to specify special boundary values, rather than using integral values directly.

A range-based boundary is based on the actual value of the ORDER BY expression(s). An offset is used to alter the value of the ORDER BY expression; for instance, if the current ORDER BY expression has a value of 10 and the lower bound offset is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This, however, puts a number of constraints on the ORDER BY expressions: there can be only one expression and this expression must have a numerical data type. An exception can be made when the offset is unbounded, because no value modification is needed; in this case multiple and non-numeric ORDER BY expressions are allowed.

New in version 2.1.0.

Parameters

start [int] boundary start, inclusive. The frame is unbounded if this is Window.unboundedPreceding, or any value less than or equal to max(-sys.maxsize, -9223372036854775808).

end [int] boundary end, inclusive. The frame is unbounded if this is Window.unboundedFollowing, or any value greater than or equal to min(sys.maxsize, 9223372036854775807).

Examples

>>> from pyspark.sql import Window
>>> from pyspark.sql import functions as func
>>> from pyspark.sql import SQLContext
>>> sc = SparkContext.getOrCreate()
>>> sqlContext = SQLContext(sc)
>>> tup = [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")]
>>> df = sqlContext.createDataFrame(tup, ["id", "category"])
>>> window = Window.partitionBy("category").orderBy("id").rangeBetween(Window.currentRow, 1)
>>> df.withColumn("sum", func.sum("id").over(window)).sort("id", "category").show()
+---+--------+---+
| id|category|sum|
+---+--------+---+
|  1|       a|  4|
|  1|       a|  4|
|  1|       b|  3|
|  2|       a|  2|
|  2|       b|  5|
|  3|       b|  3|
+---+--------+---+

pyspark.sql.Window.rowsBetween

static Window.rowsBetween(start, end)
Creates a WindowSpec with the frame boundaries defined, from start (inclusive) to end (inclusive).

Both start and end are relative positions from the current row. For example, “0” means “current row”, while “-1” means the row before the current row, and “5” means the fifth row after the current row.

We recommend users use Window.unboundedPreceding, Window.unboundedFollowing, and Window.currentRow to specify special boundary values, rather than using integral values directly.

A row-based boundary is based on the position of the row within the partition. An offset indicates the number of rows above or below the current row at which the frame for the current row starts or ends. For instance, given a row-based sliding frame with a lower bound offset of -1 and an upper bound offset of +2, the frame for the row with index 5 would range from index 4 to index 7.

New in version 2.1.0.

Parameters


start [int] boundary start, inclusive. The frame is unbounded if this is Window.unboundedPreceding, or any value less than or equal to -9223372036854775808.

end [int] boundary end, inclusive. The frame is unbounded if this is Window.unboundedFollowing, or any value greater than or equal to 9223372036854775807.

Examples

>>> from pyspark.sql import Window
>>> from pyspark.sql import functions as func
>>> from pyspark.sql import SQLContext
>>> sc = SparkContext.getOrCreate()
>>> sqlContext = SQLContext(sc)
>>> tup = [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")]
>>> df = sqlContext.createDataFrame(tup, ["id", "category"])
>>> window = Window.partitionBy("category").orderBy("id").rowsBetween(Window.currentRow, 1)
>>> df.withColumn("sum", func.sum("id").over(window)).sort("id", "category", "sum").show()
+---+--------+---+
| id|category|sum|
+---+--------+---+
|  1|       a|  2|
|  1|       a|  3|
|  1|       b|  3|
|  2|       a|  2|
|  2|       b|  5|
|  3|       b|  3|
+---+--------+---+

pyspark.sql.Window.unboundedFollowing

Window.unboundedFollowing = 9223372036854775807

pyspark.sql.Window.unboundedPreceding

Window.unboundedPreceding = -9223372036854775808
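As an illustration only, a minimal sketch of using these boundary constants for a running total (assuming the same df with id and category columns used in the rangeBetween and rowsBetween examples above):

>>> from pyspark.sql import Window
>>> from pyspark.sql import functions as func
>>> running = Window.partitionBy("category").orderBy("id") \
...     .rowsBetween(Window.unboundedPreceding, Window.currentRow)
>>> df.withColumn("running_sum", func.sum("id").over(running)).show()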

pyspark.sql.WindowSpec.orderBy

WindowSpec.orderBy(*cols)
Defines the ordering columns in a WindowSpec.

New in version 1.4.0.

Parameters

cols [str, Column or list] names of columns or expressions


pyspark.sql.WindowSpec.partitionBy

WindowSpec.partitionBy(*cols)
Defines the partitioning columns in a WindowSpec.

New in version 1.4.0.

Parameters

cols [str, Column or list] names of columns or expressions

pyspark.sql.WindowSpec.rangeBetween

WindowSpec.rangeBetween(start, end)
Defines the frame boundaries, from start (inclusive) to end (inclusive).

Both start and end are relative to the current row. For example, “0” means “current row”, while “-1” means one off before the current row, and “5” means five off after the current row.

We recommend users use Window.unboundedPreceding, Window.unboundedFollowing, and Window.currentRow to specify special boundary values, rather than using integral values directly.

New in version 1.4.0.

Parameters

start [int] boundary start, inclusive. The frame is unbounded if this is Window.unboundedPreceding, or any value less than or equal to max(-sys.maxsize, -9223372036854775808).

end [int] boundary end, inclusive. The frame is unbounded if this is Window.unboundedFollowing, or any value greater than or equal to min(sys.maxsize, 9223372036854775807).
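A minimal usage sketch (the column names are the same id and category used elsewhere in this section; this example is illustrative, not part of the original API docstring). The partitioning and ordering are defined first, then the range frame is attached to the resulting WindowSpec:

>>> from pyspark.sql import Window
>>> from pyspark.sql import functions as func
>>> base = Window.partitionBy("category").orderBy("id")
>>> spec = base.rangeBetween(Window.currentRow, 1)
>>> df.withColumn("sum", func.sum("id").over(spec))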

pyspark.sql.WindowSpec.rowsBetween

WindowSpec.rowsBetween(start, end)
Defines the frame boundaries, from start (inclusive) to end (inclusive).

Both start and end are relative positions from the current row. For example, “0” means “current row”, while “-1” means the row before the current row, and “5” means the fifth row after the current row.

We recommend users use Window.unboundedPreceding, Window.unboundedFollowing, and Window.currentRow to specify special boundary values, rather than using integral values directly.

New in version 1.4.0.

Parameters

start [int] boundary start, inclusive. The frame is unbounded if this is Window.unboundedPreceding, or any value less than or equal to max(-sys.maxsize, -9223372036854775808).

end [int] boundary end, inclusive. The frame is unbounded if this is Window.unboundedFollowing, or any value greater than or equal to min(sys.maxsize, 9223372036854775807).


3.1.10 Grouping

GroupedData.agg(*exprs) Compute aggregates and returns the result as a DataFrame.
GroupedData.apply(udf) It is an alias of pyspark.sql.GroupedData.applyInPandas(); however, it takes a pyspark.sql.functions.pandas_udf() whereas pyspark.sql.GroupedData.applyInPandas() takes a Python native function.
GroupedData.applyInPandas(func, schema) Maps each group of the current DataFrame using a pandas udf and returns the result as a DataFrame.
GroupedData.avg(*cols) Computes average values for each numeric column for each group.
GroupedData.cogroup(other) Cogroups this group with another group so that we can run cogrouped operations.
GroupedData.count() Counts the number of records for each group.
GroupedData.max(*cols) Computes the max value for each numeric column for each group.
GroupedData.mean(*cols) Computes average values for each numeric column for each group.
GroupedData.min(*cols) Computes the min value for each numeric column for each group.
GroupedData.pivot(pivot_col[, values]) Pivots a column of the current DataFrame and performs the specified aggregation.
GroupedData.sum(*cols) Computes the sum for each numeric column for each group.
PandasCogroupedOps.applyInPandas(func, schema) Applies a function to each cogroup using pandas and returns the result as a DataFrame.

pyspark.sql.GroupedData.agg

GroupedData.agg(*exprs)
Compute aggregates and returns the result as a DataFrame.

The available aggregate functions can be:

1. built-in aggregation functions, such as avg, max, min, sum, count

2. group aggregate pandas UDFs, created with pyspark.sql.functions.pandas_udf()

Note: There is no partial aggregation with group aggregate UDFs, i.e., a full shuffle is required. Also, all the data of a group will be loaded into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory.

See also:

pyspark.sql.functions.pandas_udf()

If exprs is a single dict mapping from string to string, then the key is the column to perform aggregation on, and the value is the aggregate function.

Alternatively, exprs can also be a list of aggregate Column expressions.

New in version 1.3.0.


Parameters

exprs [dict] a dict mapping from column name (string) to aggregate functions (string), or a list of Column.

Notes

Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed in a single call to this function.

Examples

>>> gdf = df.groupBy(df.name)
>>> sorted(gdf.agg({"*": "count"}).collect())
[Row(name='Alice', count(1)=1), Row(name='Bob', count(1)=1)]

>>> from pyspark.sql import functions as F
>>> sorted(gdf.agg(F.min(df.age)).collect())
[Row(name='Alice', min(age)=2), Row(name='Bob', min(age)=5)]

>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> @pandas_udf('int', PandasUDFType.GROUPED_AGG)
... def min_udf(v):
...     return v.min()
>>> sorted(gdf.agg(min_udf(df.age)).collect())
[Row(name='Alice', min_udf(age)=2), Row(name='Bob', min_udf(age)=5)]

pyspark.sql.GroupedData.apply

GroupedData.apply(udf)
It is an alias of pyspark.sql.GroupedData.applyInPandas(); however, it takes a pyspark.sql.functions.pandas_udf() whereas pyspark.sql.GroupedData.applyInPandas() takes a Python native function.

New in version 2.3.0.

Parameters

udf [pyspark.sql.functions.pandas_udf()] a grouped map user-defined function returned by pyspark.sql.functions.pandas_udf().

See also:

pyspark.sql.functions.pandas_udf


Notes

It is preferred to use pyspark.sql.GroupedData.applyInPandas() over this API. This API will be deprecated in future releases.

Examples

>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
... def normalize(pdf):
...     v = pdf.v
...     return pdf.assign(v=(v - v.mean()) / v.std())
>>> df.groupby("id").apply(normalize).show()
+---+-------------------+
| id|                  v|
+---+-------------------+
|  1|-0.7071067811865475|
|  1| 0.7071067811865475|
|  2|-0.8320502943378437|
|  2|-0.2773500981126146|
|  2| 1.1094003924504583|
+---+-------------------+

pyspark.sql.GroupedData.applyInPandas

GroupedData.applyInPandas(func, schema)
Maps each group of the current DataFrame using a pandas udf and returns the result as a DataFrame.

The function should take a pandas.DataFrame and return another pandas.DataFrame. For each group, all columns are passed together as a pandas.DataFrame to the user-function and the returned pandas.DataFrame are combined as a DataFrame.

The schema should be a StructType describing the schema of the returned pandas.DataFrame. The column labels of the returned pandas.DataFrame must either match the field names in the defined schema if specified as strings, or match the field data types by position if not strings, e.g. integer indices. The length of the returned pandas.DataFrame can be arbitrary.

New in version 3.0.0.

Parameters

func [function] a Python native function that takes a pandas.DataFrame, and outputs a pandas.DataFrame.

schema [pyspark.sql.types.DataType or str] the return type of the func in PySpark. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.

See also:

pyspark.sql.functions.pandas_udf


Notes

This function requires a full shuffle. All the data of a group will be loaded into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory.

If returning a new pandas.DataFrame constructed with a dictionary, it is recommended to explicitly index the columns by name to ensure the positions are correct, or alternatively use an OrderedDict. For example, pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a']) or pd.DataFrame(OrderedDict([('id', ids), ('a', data)])).

This API is experimental.

Examples

>>> import pandas as pd
>>> from pyspark.sql.functions import pandas_udf, ceil
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> def normalize(pdf):
...     v = pdf.v
...     return pdf.assign(v=(v - v.mean()) / v.std())
>>> df.groupby("id").applyInPandas(
...     normalize, schema="id long, v double").show()
+---+-------------------+
| id|                  v|
+---+-------------------+
|  1|-0.7071067811865475|
|  1| 0.7071067811865475|
|  2|-0.8320502943378437|
|  2|-0.2773500981126146|
|  2| 1.1094003924504583|
+---+-------------------+

Alternatively, the user can pass a function that takes two arguments. In this case, the grouping key(s) will be passed as the first argument and the data will be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy data types, e.g., numpy.int32 and numpy.float64. The data will still be passed in as a pandas.DataFrame containing all columns from the original Spark DataFrame. This is useful when the user does not want to hardcode grouping key(s) in the function.

>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> def mean_func(key, pdf):
...     # key is a tuple of one numpy.int64, which is the value
...     # of 'id' for the current group
...     return pd.DataFrame([key + (pdf.v.mean(),)])
>>> df.groupby('id').applyInPandas(
...     mean_func, schema="id long, v double").show()
+---+---+
| id|  v|
+---+---+
|  1|1.5|
|  2|6.0|
+---+---+


>>> def sum_func(key, pdf):
...     # key is a tuple of two numpy.int64s, which is the values
...     # of 'id' and 'ceil(df.v / 2)' for the current group
...     return pd.DataFrame([key + (pdf.v.sum(),)])
>>> df.groupby(df.id, ceil(df.v / 2)).applyInPandas(
...     sum_func, schema="id long, `ceil(v / 2)` long, v double").show()
+---+-----------+----+
| id|ceil(v / 2)|   v|
+---+-----------+----+
|  2|          5|10.0|
|  1|          1| 3.0|
|  2|          3| 5.0|
|  2|          2| 3.0|
+---+-----------+----+

pyspark.sql.GroupedData.avg

GroupedData.avg(*cols)
Computes average values for each numeric column for each group.

mean() is an alias for avg().

New in version 1.3.0.

Parameters

cols [str] column names. Non-numeric columns are ignored.

Examples

>>> df.groupBy().avg('age').collect()
[Row(avg(age)=3.5)]
>>> df3.groupBy().avg('age', 'height').collect()
[Row(avg(age)=3.5, avg(height)=82.5)]

pyspark.sql.GroupedData.cogroup

GroupedData.cogroup(other)
Cogroups this group with another group so that we can run cogrouped operations.

New in version 3.0.0.

See PandasCogroupedOps for the operations that can be run.
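A minimal sketch of what cogrouping enables (assuming two DataFrames df1 and df2 that share an id column; the column names in the schema string are illustrative — see pyspark.sql.PandasCogroupedOps.applyInPandas below for a complete, runnable example):

>>> import pandas as pd
>>> def merge_pdfs(left, right):
...     # each side of the cogroup arrives as a pandas.DataFrame
...     return pd.merge(left, right, on="id")
...
>>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
...     merge_pdfs, schema="id long, v1 double, v2 string")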

pyspark.sql.GroupedData.count

GroupedData.count()
Counts the number of records for each group.

New in version 1.3.0.


Examples

>>> sorted(df.groupBy(df.age).count().collect())
[Row(age=2, count=1), Row(age=5, count=1)]

pyspark.sql.GroupedData.max

GroupedData.max(*cols)
Computes the max value for each numeric column for each group.

New in version 1.3.0.

Examples

>>> df.groupBy().max('age').collect()
[Row(max(age)=5)]
>>> df3.groupBy().max('age', 'height').collect()
[Row(max(age)=5, max(height)=85)]

pyspark.sql.GroupedData.mean

GroupedData.mean(*cols)
Computes average values for each numeric column for each group.

mean() is an alias for avg().

New in version 1.3.0.

Parameters

cols [str] column names. Non-numeric columns are ignored.

Examples

>>> df.groupBy().mean('age').collect()
[Row(avg(age)=3.5)]
>>> df3.groupBy().mean('age', 'height').collect()
[Row(avg(age)=3.5, avg(height)=82.5)]

pyspark.sql.GroupedData.min

GroupedData.min(*cols)
Computes the min value for each numeric column for each group.

New in version 1.3.0.

Parameters

cols [str] column names. Non-numeric columns are ignored.


Examples

>>> df.groupBy().min('age').collect()
[Row(min(age)=2)]
>>> df3.groupBy().min('age', 'height').collect()
[Row(min(age)=2, min(height)=80)]

pyspark.sql.GroupedData.pivot

GroupedData.pivot(pivot_col, values=None)
Pivots a column of the current DataFrame and performs the specified aggregation. There are two versions of the pivot function: one that requires the caller to specify the list of distinct values to pivot on, and one that does not. The latter is more concise but less efficient, because Spark needs to first compute the list of distinct values internally.

New in version 1.6.0.

Parameters

pivot_col [str] Name of the column to pivot.

values : List of values that will be translated to columns in the output DataFrame.

Examples

# Compute the sum of earnings for each year by course with each course as a separate column

>>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").→˓collect()[Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000,→˓Java=30000)]

# Or without specifying column values (less efficient)

>>> df4.groupBy("year").pivot("course").sum("earnings").collect()[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000,→˓dotNET=48000)]>>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").→˓collect()[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000,→˓dotNET=48000)]

pyspark.sql.GroupedData.sum

GroupedData.sum(*cols)
Computes the sum for each numeric column for each group.

New in version 1.3.0.

Parameters

cols [str] column names. Non-numeric columns are ignored.


Examples

>>> df.groupBy().sum('age').collect()
[Row(sum(age)=7)]
>>> df3.groupBy().sum('age', 'height').collect()
[Row(sum(age)=7, sum(height)=165)]

pyspark.sql.PandasCogroupedOps.applyInPandas

PandasCogroupedOps.applyInPandas(func, schema)
Applies a function to each cogroup using pandas and returns the result as a DataFrame.

The function should take two pandas.DataFrames and return another pandas.DataFrame. For each side of the cogroup, all columns are passed together as a pandas.DataFrame to the user-function and the returned pandas.DataFrame are combined as a DataFrame.

The schema should be a StructType describing the schema of the returned pandas.DataFrame. The column labels of the returned pandas.DataFrame must either match the field names in the defined schema if specified as strings, or match the field data types by position if not strings, e.g. integer indices. The length of the returned pandas.DataFrame can be arbitrary.

New in version 3.0.0.

Parameters

func [function] a Python native function that takes two pandas.DataFrames, and outputs a pandas.DataFrame, or that takes one tuple (grouping keys) and two pandas DataFrames, and outputs a pandas DataFrame.

schema [pyspark.sql.types.DataType or str] the return type of the func in PySpark. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.

See also:

pyspark.sql.functions.pandas_udf

Notes

This function requires a full shuffle. All the data of a cogroup will be loaded into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory.

If returning a new pandas.DataFrame constructed with a dictionary, it is recommended to explicitly index the columns by name to ensure the positions are correct, or alternatively use an OrderedDict. For example, pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a']) or pd.DataFrame(OrderedDict([('id', ids), ('a', data)])).

This API is experimental.


Examples

>>> from pyspark.sql.functions import pandas_udf
>>> df1 = spark.createDataFrame(
...     [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
...     ("time", "id", "v1"))
>>> df2 = spark.createDataFrame(
...     [(20000101, 1, "x"), (20000101, 2, "y")],
...     ("time", "id", "v2"))
>>> def asof_join(l, r):
...     return pd.merge_asof(l, r, on="time", by="id")
>>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
...     asof_join, schema="time int, id int, v1 double, v2 string"
... ).show()
+--------+---+---+---+
|    time| id| v1| v2|
+--------+---+---+---+
|20000101|  1|1.0|  x|
|20000102|  1|3.0|  x|
|20000101|  2|2.0|  y|
|20000102|  2|4.0|  y|
+--------+---+---+---+

Alternatively, the user can define a function that takes three arguments. In this case, the grouping key(s) will be passed as the first argument and the data will be passed as the second and third arguments. The grouping key(s) will be passed as a tuple of numpy data types, e.g., numpy.int32 and numpy.float64. The data will still be passed in as two pandas.DataFrames containing all columns from the original Spark DataFrames.

>>> def asof_join(k, l, r):
...     if k == (1,):
...         return pd.merge_asof(l, r, on="time", by="id")
...     else:
...         return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
>>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
...     asof_join, "time int, id int, v1 double, v2 string").show()
+--------+---+---+---+
|    time| id| v1| v2|
+--------+---+---+---+
|20000101|  1|1.0|  x|
|20000102|  1|3.0|  x|
+--------+---+---+---+

3.2 Structured Streaming

3.2.1 Core Classes

DataStreamReader(spark) Interface used to load a streaming DataFrame from external storage systems (e.g. file systems, key-value stores, etc.).
DataStreamWriter(df) Interface used to write a streaming DataFrame to external storage systems (e.g. file systems, key-value stores, etc.).
ForeachBatchFunction(sql_ctx, func) This is the Python implementation of Java interface ‘ForeachBatchFunction’.
StreamingQuery(jsq) A handle to a query that is executing continuously in the background as new data arrives.
StreamingQueryException(desc, stackTrace[, . . . ]) Exception that stopped a StreamingQuery.
StreamingQueryManager(jsqm) A class to manage all the active StreamingQuery queries.

pyspark.sql.streaming.DataStreamReader

class pyspark.sql.streaming.DataStreamReader(spark)
Interface used to load a streaming DataFrame from external storage systems (e.g. file systems, key-value stores, etc). Use SparkSession.readStream to access this.

New in version 2.0.0.

Notes

This API is evolving.
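A minimal access sketch (the path and schema string below are illustrative and not part of the original documentation):

>>> reader = spark.readStream  # returns a DataStreamReader
>>> sdf = reader.format("json").schema("col0 INT, col1 DOUBLE").load("/tmp/input")
>>> sdf.isStreaming
True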

Methods

csv(path[, schema, sep, encoding, quote, . . . ]) Loads a CSV file stream and returns the result as a DataFrame.
format(source) Specifies the input data source format.
json(path[, schema, primitivesAsString, . . . ]) Loads a JSON file stream and returns the results as a DataFrame.
load([path, format, schema]) Loads a data stream from a data source and returns it as a DataFrame.
option(key, value) Adds an input option for the underlying data source.
options(**options) Adds input options for the underlying data source.
orc(path[, mergeSchema, pathGlobFilter, . . . ]) Loads an ORC file stream, returning the result as a DataFrame.
parquet(path[, mergeSchema, pathGlobFilter, . . . ]) Loads a Parquet file stream, returning the result as a DataFrame.
schema(schema) Specifies the input schema.
text(path[, wholetext, lineSep, . . . ]) Loads a text file stream and returns a DataFrame whose schema starts with a string column named “value”, followed by partitioned columns if there are any.


pyspark.sql.streaming.DataStreamWriter

class pyspark.sql.streaming.DataStreamWriter(df)
Interface used to write a streaming DataFrame to external storage systems (e.g. file systems, key-value stores, etc). Use DataFrame.writeStream to access this.

New in version 2.0.0.

Notes

This API is evolving.
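A minimal access sketch (assuming a streaming DataFrame sdf; the sink format and paths below are illustrative):

>>> query = sdf.writeStream \
...     .format("parquet") \
...     .option("checkpointLocation", "/tmp/checkpoints") \
...     .outputMode("append") \
...     .start("/tmp/output")
>>> query.isActive
True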

Methods

foreach(f) Sets the output of the streaming query to be processed using the provided writer f.
foreachBatch(func) Sets the output of the streaming query to be processed using the provided function.
format(source) Specifies the underlying output data source.
option(key, value) Adds an output option for the underlying data source.
options(**options) Adds output options for the underlying data source.
outputMode(outputMode) Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
partitionBy(*cols) Partitions the output by the given columns on the file system.
queryName(queryName) Specifies the name of the StreamingQuery that can be started with start().
start([path, format, outputMode, . . . ]) Streams the contents of the DataFrame to a data source.
trigger(*args, **kwargs) Set the trigger for the stream query.

pyspark.sql.streaming.ForeachBatchFunction

class pyspark.sql.streaming.ForeachBatchFunction(sql_ctx, func)
This is the Python implementation of Java interface ‘ForeachBatchFunction’. This wraps the user-defined ‘foreachBatch’ function such that it can be called from the JVM when the query is active.

Methods

call(jdf, batch_id)


pyspark.sql.streaming.StreamingQuery

class pyspark.sql.streaming.StreamingQuery(jsq)
A handle to a query that is executing continuously in the background as new data arrives. All these methods are thread-safe.

New in version 2.0.0.

Notes

This API is evolving.

Methods

awaitTermination([timeout]) Waits for the termination of this query, either by query.stop() or by an exception.
exception() New in version 2.1.0.
explain([extended]) Prints the (logical and physical) plans to the console for debugging purposes.
processAllAvailable() Blocks until all available data in the source has been processed and committed to the sink.
stop() Stop this streaming query.

Attributes

id Returns the unique id of this query that persists across restarts from checkpoint data.
isActive Whether this streaming query is currently active or not.
lastProgress Returns the most recent StreamingQueryProgress update of this streaming query, or None if there were no progress updates.
name Returns the user-specified name of the query, or null if not specified.
recentProgress Returns an array of the most recent [[StreamingQueryProgress]] updates for this query.
runId Returns the unique id of this query that does not persist across restarts.
status Returns the current status of the query.


pyspark.sql.streaming.StreamingQueryException

exception pyspark.sql.streaming.StreamingQueryException(desc, stackTrace, cause=None)
Exception that stopped a StreamingQuery.

pyspark.sql.streaming.StreamingQueryManager

class pyspark.sql.streaming.StreamingQueryManager(jsqm)
A class to manage all the active StreamingQuery queries.

New in version 2.0.0.

Notes

This API is evolving.

Methods

awaitAnyTermination([timeout]) Wait until any of the queries on the associated SQLContext has terminated since the creation of the context, or since resetTerminated() was called.
get(id) Returns an active query from this SQLContext, or throws an exception if an active query with this name doesn’t exist.
resetTerminated() Forget about past terminated queries so that awaitAnyTermination() can be used again to wait for new terminations.

Attributes

active Returns a list of active queries associated with this SQLContext.
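A minimal usage sketch (assuming an active SparkSession spark with at least one running streaming query):

>>> manager = spark.streams
>>> [q.name for q in manager.active]  # names of the currently active queries
>>> manager.awaitAnyTermination(timeout=5)  # wait up to 5 seconds for any query to stop
>>> manager.resetTerminated()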

3.2.2 Input and Output

DataStreamReader.csv(path[, schema, sep, . . . ]) Loads a CSV file stream and returns the result as a DataFrame.
DataStreamReader.format(source) Specifies the input data source format.
DataStreamReader.json(path[, schema, . . . ]) Loads a JSON file stream and returns the results as a DataFrame.
DataStreamReader.load([path, format, schema]) Loads a data stream from a data source and returns it as a DataFrame.
DataStreamReader.option(key, value) Adds an input option for the underlying data source.
DataStreamReader.options(**options) Adds input options for the underlying data source.
DataStreamReader.orc(path[, mergeSchema, . . . ]) Loads an ORC file stream, returning the result as a DataFrame.
DataStreamReader.parquet(path[, . . . ]) Loads a Parquet file stream, returning the result as a DataFrame.
DataStreamReader.schema(schema) Specifies the input schema.
DataStreamReader.text(path[, wholetext, . . . ]) Loads a text file stream and returns a DataFrame whose schema starts with a string column named “value”, followed by partitioned columns if there are any.
DataStreamWriter.foreach(f) Sets the output of the streaming query to be processed using the provided writer f.
DataStreamWriter.foreachBatch(func) Sets the output of the streaming query to be processed using the provided function.
DataStreamWriter.format(source) Specifies the underlying output data source.
DataStreamWriter.option(key, value) Adds an output option for the underlying data source.
DataStreamWriter.options(**options) Adds output options for the underlying data source.
DataStreamWriter.outputMode(outputMode) Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
DataStreamWriter.partitionBy(*cols) Partitions the output by the given columns on the file system.
DataStreamWriter.queryName(queryName) Specifies the name of the StreamingQuery that can be started with start().
DataStreamWriter.start([path, format, . . . ]) Streams the contents of the DataFrame to a data source.
DataStreamWriter.trigger(*args, **kwargs) Set the trigger for the stream query.

pyspark.sql.streaming.DataStreamReader.csv

DataStreamReader.csv(path, schema=None, sep=None, encoding=None, quote=None, escape=None, comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None, pathGlobFilter=None, recursiveFileLookup=None, unescapedQuoteHandling=None)

Loads a CSV file stream and returns the result as a DataFrame.

This function will go through the input once to determine the input schema if inferSchema is enabled. To avoid going through the entire data once, disable the inferSchema option or specify the schema explicitly using schema.

Parameters

path [str or list] string, or list of strings, for input path(s).

schema [pyspark.sql.types.StructType or str, optional] an optional pyspark.sql.types.StructType for the input schema or a DDL-formatted string (For example col0 INT, col1 DOUBLE).

sep [str, optional] sets a separator (one or more characters) for each field and value. If None is set, it uses the default value, ,.

encoding [str, optional] decodes the CSV files by the given encoding type. If None is set, it uses the default value, UTF-8.


quote [str, optional] sets a single character used for escaping quoted values where the separator can be part of the value. If None is set, it uses the default value, ". If you would like to turn off quotations, you need to set an empty string.

escape [str, optional] sets a single character used for escaping quotes inside an already quoted value. If None is set, it uses the default value, \.

comment [str, optional] sets a single character used for skipping lines beginning with this character. By default (None), it is disabled.

header [str or bool, optional] uses the first line as names of columns. If None is set, it uses the default value, false.

inferSchema [str or bool, optional] infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, false.

enforceSchema [str or bool, optional] If it is set to true, the specified or inferred schema will be forcibly applied to datasource files, and headers in CSV files will be ignored. If the option is set to false, the schema will be validated against all headers in CSV files or the first header in RDD if the header option is set to true. Field names in the schema and column names in CSV headers are checked by their positions taking into account spark.sql.caseSensitive. If None is set, true is used by default. Though the default value is true, it is recommended to disable the enforceSchema option to avoid incorrect results.

ignoreLeadingWhiteSpace [str or bool, optional] a flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, false.

ignoreTrailingWhiteSpace [str or bool, optional] a flag indicating whether or not trailing whitespaces from values being read should be skipped. If None is set, it uses the default value, false.

nullValue [str, optional] sets the string representation of a null value. If None is set, it uses the default value, empty string. Since 2.0.1, this nullValue param applies to all supported types including the string type.

nanValue [str, optional] sets the string representation of a non-number value. If None is set, it uses the default value, NaN.

positiveInf [str, optional] sets the string representation of a positive infinity value. If None is set, it uses the default value, Inf.

negativeInf [str, optional] sets the string representation of a negative infinity value. If None is set, it uses the default value, Inf.

dateFormat [str, optional] sets the string that indicates a date format. Custom date formats follow the formats at datetime pattern. This applies to date type. If None is set, it uses the default value, yyyy-MM-dd.

timestampFormat [str, optional] sets the string that indicates a timestamp format. Custom date formats follow the formats at datetime pattern. This applies to timestamp type. If None is set, it uses the default value, yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX].

maxColumns [str or int, optional] defines a hard limit of how many columns a record can have. If None is set, it uses the default value, 20480.

maxCharsPerColumn [str or int, optional] defines the maximum number of characters allowed for any given value being read. If None is set, it uses the default value, -1 meaning unlimited length.


maxMalformedLogPerPartition [str or int, optional] this parameter is no longer used since Spark 2.2.0. If specified, it is ignored.

mode [str, optional] allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, PERMISSIVE.

• PERMISSIVE: when it meets a corrupted record, puts the malformed string into a field configured by columnNameOfCorruptRecord, and sets malformed fields to null. To keep corrupt records, a user can set a string type field named columnNameOfCorruptRecord in a user-defined schema. If a schema does not have the field, it drops corrupt records during parsing. A record with less/more tokens than the schema is not a corrupted record to CSV. When it meets a record having fewer tokens than the length of the schema, it sets null to extra fields. When the record has more tokens than the length of the schema, it drops extra tokens.

• DROPMALFORMED: ignores the whole corrupted records.

• FAILFAST: throws an exception when it meets corrupted records.

columnNameOfCorruptRecord [str, optional] allows renaming the new field having malformed string created by PERMISSIVE mode. This overrides spark.sql.columnNameOfCorruptRecord. If None is set, it uses the value specified in spark.sql.columnNameOfCorruptRecord.

multiLine [str or bool, optional] parse one record, which may span multiple lines. If None is set, it uses the default value, false.

charToEscapeQuoteEscaping [str, optional] sets a single character used for escaping the escape for the quote character. If None is set, the default value is the escape character when escape and quote characters are different, \0 otherwise.

emptyValue [str, optional] sets the string representation of an empty value. If None is set, it uses the default value, empty string.

locale [str, optional] sets a locale as a language tag in IETF BCP 47 format. If None is set, it uses the default value, en-US. For instance, locale is used while parsing dates and timestamps.

lineSep [str, optional] defines the line separator that should be used for parsing. If None is set, it covers all \\r, \\r\\n and \\n. Maximum length is 1 character.

pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.

recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.

unescapedQuoteHandling [str, optional] defines how the CsvParser will handle values with unescaped quotes. If None is set, it uses the default value, STOP_AT_DELIMITER.

• STOP_AT_CLOSING_QUOTE: If unescaped quotes are found in the input, accumulate the quote character and proceed parsing the value as a quoted value, until a closing quote is found.

• BACK_TO_DELIMITER: If unescaped quotes are found in the input, consider the value as an unquoted value. This will make the parser accumulate all characters of the current parsed value until the delimiter is found. If no delimiter is found in the value, the parser will continue accumulating characters from the input until a delimiter or line ending is found.


• STOP_AT_DELIMITER: If unescaped quotes are found in the input, consider the value as an unquoted value. This will make the parser accumulate all characters until the delimiter or a line ending is found in the input.

• SKIP_VALUE: If unescaped quotes are found in the input, the content parsed for the given value will be skipped and the value set in nullValue will be produced instead.

• RAISE_ERROR: If unescaped quotes are found in the input, a TextParsingException will be thrown.

New in version 2.0.0.

Notes

This API is evolving.

Examples

>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
>>> csv_sdf.isStreaming
True
>>> csv_sdf.schema == sdf_schema
True

pyspark.sql.streaming.DataStreamReader.format

DataStreamReader.format(source)
Specifies the input data source format.

New in version 2.0.0.

Parameters

source [str] name of the data source, e.g. ‘json’, ‘parquet’.

Notes

This API is evolving.

Examples

>>> s = spark.readStream.format("text")


pyspark.sql.streaming.DataStreamReader.json

DataStreamReader.json(path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None, dropFieldIfAllNull=None, encoding=None, pathGlobFilter=None, recursiveFileLookup=None, allowNonNumericNumbers=None)

Loads a JSON file stream and returns the results as a DataFrame.

JSON Lines (newline-delimited JSON) is supported by default. For JSON (one record per file), set the multiLine parameter to true.

If the schema parameter is not specified, this function goes through the input once to determine the input schema.

New in version 2.0.0.

Parameters

path [str] string represents path to the JSON dataset, or RDD of Strings storing JSON objects.

schema [pyspark.sql.types.StructType or str, optional] an optional pyspark.sql.types.StructType for the input schema or a DDL-formatted string (For example col0 INT, col1 DOUBLE).

primitivesAsString [str or bool, optional] infers all primitive values as a string type. If None is set, it uses the default value, false.

prefersDecimal [str or bool, optional] infers all floating-point values as a decimal type. If the values do not fit in decimal, then it infers them as doubles. If None is set, it uses the default value, false.

allowComments [str or bool, optional] ignores Java/C++ style comment in JSON records. If None is set, it uses the default value, false.

allowUnquotedFieldNames [str or bool, optional] allows unquoted JSON field names. If None is set, it uses the default value, false.

allowSingleQuotes [str or bool, optional] allows single quotes in addition to double quotes. If None is set, it uses the default value, true.

allowNumericLeadingZero [str or bool, optional] allows leading zeros in numbers (e.g. 00012). If None is set, it uses the default value, false.

allowBackslashEscapingAnyCharacter [str or bool, optional] allows accepting quoting of all character using backslash quoting mechanism. If None is set, it uses the default value, false.

mode [str, optional] allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, PERMISSIVE.

• PERMISSIVE: when it meets a corrupted record, puts the malformed string into a field configured by columnNameOfCorruptRecord, and sets malformed fields to null. To keep corrupt records, a user can set a string type field named columnNameOfCorruptRecord in a user-defined schema. If a schema does not have the field, it drops corrupt records during parsing. When inferring a schema, it implicitly adds a columnNameOfCorruptRecord field in an output schema.

• DROPMALFORMED: ignores the whole corrupted records.


• FAILFAST: throws an exception when it meets corrupted records.

columnNameOfCorruptRecord [str, optional] allows renaming the new field having malformed string created by PERMISSIVE mode. This overrides spark.sql.columnNameOfCorruptRecord. If None is set, it uses the value specified in spark.sql.columnNameOfCorruptRecord.

dateFormat [str, optional] sets the string that indicates a date format. Custom date formats follow the formats at datetime pattern. This applies to date type. If None is set, it uses the default value, yyyy-MM-dd.

timestampFormat [str, optional] sets the string that indicates a timestamp format. Custom date formats follow the formats at datetime pattern. This applies to timestamp type. If None is set, it uses the default value, yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX].

multiLine [str or bool, optional] parse one record, which may span multiple lines, per file. If None is set, it uses the default value, false.

allowUnquotedControlChars [str or bool, optional] allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not.

lineSep [str, optional] defines the line separator that should be used for parsing. If None is set, it covers all \r, \r\n and \n.

locale [str, optional] sets a locale as a language tag in IETF BCP 47 format. If None is set, it uses the default value, en-US. For instance, locale is used while parsing dates and timestamps.

dropFieldIfAllNull [str or bool, optional] whether to ignore column of all null values or empty array/struct during schema inference. If None is set, it uses the default value, false.

encoding [str or bool, optional] allows to forcibly set one of standard basic or extended encoding for the JSON files. For example UTF-16BE, UTF-32LE. If None is set, the encoding of input JSON will be detected automatically when the multiLine option is set to true.

pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.

recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.

allowNonNumericNumbers [str or bool, optional] allows JSON parser to recognize set of “Not-a-Number” (NaN) tokens as legal floating number values. If None is set, it uses the default value, true.

• +INF: for positive infinity, as well as alias of +Infinity and Infinity.

• -INF: for negative infinity, alias -Infinity.

• NaN: for other not-a-numbers, like result of division by zero.


Notes

This API is evolving.

Examples

>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
>>> json_sdf.isStreaming
True
>>> json_sdf.schema == sdf_schema
True

pyspark.sql.streaming.DataStreamReader.load

DataStreamReader.load(path=None, format=None, schema=None, **options)
Loads a data stream from a data source and returns it as a DataFrame.

New in version 2.0.0.

Parameters

path [str, optional] optional string for file-system backed data sources.

format [str, optional] optional string for format of the data source. Defaults to ‘parquet’.

schema [pyspark.sql.types.StructType or str, optional] optional pyspark.sql.types.StructType for the input schema or a DDL-formatted string (For example col0 INT, col1 DOUBLE).

**options [dict] all other string options

Notes

This API is evolving.

Examples

>>> json_sdf = spark.readStream.format("json") \... .schema(sdf_schema) \... .load(tempfile.mkdtemp())>>> json_sdf.isStreamingTrue>>> json_sdf.schema == sdf_schemaTrue


pyspark.sql.streaming.DataStreamReader.option

DataStreamReader.option(key, value)
Adds an input option for the underlying data source.

You can set the following option(s) for reading files:

• timeZone: sets the string that indicates a time zone ID to be used to parse timestamps in the JSON/CSV datasources or partition values. The following formats of timeZone are supported:

– Region-based zone ID: It should have the form ‘area/city’, such as ‘America/Los_Angeles’.

– Zone offset: It should be in the format ‘(+|-)HH:mm’, for example ‘-08:00’ or ‘+01:00’. Also ‘UTC’ and ‘Z’ are supported as aliases of ‘+00:00’.

Other short names like ‘CST’ are not recommended to use because they can be ambiguous. If it isn’t set, the current value of the SQL config spark.sql.session.timeZone is used by default.

New in version 2.0.0.

Notes

This API is evolving.

Examples

>>> s = spark.readStream.option("x", 1)

pyspark.sql.streaming.DataStreamReader.options

DataStreamReader.options(**options)
Adds input options for the underlying data source.

You can set the following option(s) for reading files:

• timeZone: sets the string that indicates a time zone ID to be used to parse timestamps in the JSON/CSV datasources or partition values. The following formats of timeZone are supported:

– Region-based zone ID: It should have the form ‘area/city’, such as ‘America/Los_Angeles’.

– Zone offset: It should be in the format ‘(+|-)HH:mm’, for example ‘-08:00’ or ‘+01:00’. Also ‘UTC’ and ‘Z’ are supported as aliases of ‘+00:00’.

Other short names like ‘CST’ are not recommended to use because they can be ambiguous. If it isn’t set, the current value of the SQL config spark.sql.session.timeZone is used by default.

New in version 2.0.0.


Notes

This API is evolving.

Examples

>>> s = spark.readStream.options(x="1", y=2)

pyspark.sql.streaming.DataStreamReader.orc

DataStreamReader.orc(path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None)

Loads an ORC file stream, returning the result as a DataFrame.

New in version 2.3.0.

Parameters

mergeSchema [str or bool, optional] sets whether we should merge schemas collected from all ORC part-files. This will override spark.sql.orc.mergeSchema. The default value is specified in spark.sql.orc.mergeSchema.

pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.

recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.

Examples

>>> orc_sdf = spark.readStream.schema(sdf_schema).orc(tempfile.mkdtemp())
>>> orc_sdf.isStreaming
True
>>> orc_sdf.schema == sdf_schema
True

pyspark.sql.streaming.DataStreamReader.parquet

DataStreamReader.parquet(path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None)

Loads a Parquet file stream, returning the result as a DataFrame.

New in version 2.0.0.

Parameters

mergeSchema [str or bool, optional] sets whether we should merge schemas collected from all Parquet part-files. This will override spark.sql.parquet.mergeSchema. The default value is specified in spark.sql.parquet.mergeSchema.

pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.


recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.

Examples

>>> parquet_sdf = spark.readStream.schema(sdf_schema).parquet(tempfile.mkdtemp())
>>> parquet_sdf.isStreaming
True
>>> parquet_sdf.schema == sdf_schema
True

pyspark.sql.streaming.DataStreamReader.schema

DataStreamReader.schema(schema)
Specifies the input schema.

Some data sources (e.g. JSON) can infer the input schema automatically from data. By specifying the schema here, the underlying data source can skip the schema inference step, and thus speed up data loading.

New in version 2.0.0.

Parameters

schema [pyspark.sql.types.StructType or str] a pyspark.sql.types.StructType object or a DDL-formatted string (For example col0 INT, col1 DOUBLE).

Notes

This API is evolving.

Examples

>>> s = spark.readStream.schema(sdf_schema)
>>> s = spark.readStream.schema("col0 INT, col1 DOUBLE")

pyspark.sql.streaming.DataStreamReader.text

DataStreamReader.text(path, wholetext=False, lineSep=None, pathGlobFilter=None, recursiveFileLookup=None)

Loads a text file stream and returns a DataFrame whose schema starts with a string column named “value”, followed by partitioned columns if there are any. The text files must be encoded as UTF-8.

By default, each line in the text file is a new row in the resulting DataFrame.

New in version 2.0.0.

Parameters

paths [str or list] string, or list of strings, for input path(s).

wholetext [str or bool, optional] if true, read each file from input path(s) as a single row.

lineSep [str, optional] defines the line separator that should be used for parsing. If None is set, it covers all \r, \r\n and \n.


pathGlobFilter [str or bool, optional] an optional glob pattern to only include files with paths matching the pattern. The syntax follows org.apache.hadoop.fs.GlobFilter. It does not change the behavior of partition discovery.

recursiveFileLookup [str or bool, optional] recursively scan a directory for files. Using this option disables partition discovery.

Notes

This API is evolving.

Examples

>>> text_sdf = spark.readStream.text(tempfile.mkdtemp())
>>> text_sdf.isStreaming
True
>>> "value" in str(text_sdf.schema)
True

pyspark.sql.streaming.DataStreamWriter.foreach

DataStreamWriter.foreach(f)
Sets the output of the streaming query to be processed using the provided writer f. This is often used to write the output of a streaming query to arbitrary storage systems. The processing logic can be specified in two ways.

1. A function that takes a row as input. This is a simple way to express your processing logic. Note that this does not allow you to deduplicate generated data when failures cause reprocessing of some input data. That would require you to specify the processing logic in the next way.

2. An object with a process method and optional open and close methods. The object can have the following methods.

• open(partition_id, epoch_id): Optional method that initializes the processing (for example, open a connection, start a transaction, etc). Additionally, you can use the partition_id and epoch_id to deduplicate regenerated data (discussed later).

• process(row): Non-optional method that processes each Row.

• close(error): Optional method that finalizes and cleans up (for example, close connection, commit transaction, etc.) after all rows have been processed.

The object will be used by Spark in the following way.

• A single copy of this object is responsible for all the data generated by a single task in a query. In other words, one instance is responsible for processing one partition of the data generated in a distributed manner.

• This object must be serializable because each task will get a fresh serialized-deserialized copy of the provided object. Hence, it is strongly recommended that any initialization for writing data (e.g. opening a connection or starting a transaction) is done after the open(. . . ) method has been called, which signifies that the task is ready to generate data.

• The lifecycle of the methods are as follows.

For each partition with partition_id:

. . . For each batch/epoch of streaming data with epoch_id:

3.2. Structured Streaming 269

Page 274: pyspark Documentation

pyspark Documentation, Release master

. . . . . . . Method open(partitionId, epochId) is called.

. . . . . . . If open(...) returns true, for each row in the partition and batch/epoch,method process(row) is called.

. . . . . . . Method close(errorOrNull) is called with error (if any) seen whileprocessing rows.

Important points to note:

• The partitionId and epochId can be used to deduplicate generated data when failures causereprocessing of some input data. This depends on the execution mode of the query. If thestreaming query is being executed in the micro-batch mode, then every partition representedby a unique tuple (partition_id, epoch_id) is guaranteed to have the same data. Hence, (parti-tion_id, epoch_id) can be used to deduplicate and/or transactionally commit data and achieveexactly-once guarantees. However, if the streaming query is being executed in the continuousmode, then this guarantee does not hold and therefore should not be used for deduplication.

• The close() method (if exists) will be called if open() method exists and returns success-fully (irrespective of the return value), except if the Python crashes in the middle.

New in version 2.4.0.

Notes

This API is evolving.

Examples

>>> # Print every row using a function
>>> def print_row(row):
...     print(row)
...
>>> writer = sdf.writeStream.foreach(print_row)
>>> # Print every row using an object with a process() method
>>> class RowPrinter:
...     def open(self, partition_id, epoch_id):
...         print("Opened %d, %d" % (partition_id, epoch_id))
...         return True
...     def process(self, row):
...         print(row)
...     def close(self, error):
...         print("Closed with error: %s" % str(error))
...
>>> writer = sdf.writeStream.foreach(RowPrinter())
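
A further sketch (not from the upstream docs) of using (partition_id, epoch_id) to make reprocessing idempotent: each partition/epoch pair writes to its own output path, so a re-run overwrites rather than duplicates. The local path below is hypothetical and assumes every executor can write to it.

>>> class PartitionEpochWriter:
...     def open(self, partition_id, epoch_id):
...         # One output file per (partition, epoch); reprocessing overwrites it
...         self.rows = []
...         self.path = "/tmp/out-%d-%d.txt" % (partition_id, epoch_id)
...         return True
...     def process(self, row):
...         self.rows.append(str(row))
...     def close(self, error):
...         if error is None:
...             with open(self.path, "w") as f:
...                 f.write("\n".join(self.rows))
...
>>> writer = sdf.writeStream.foreach(PartitionEpochWriter())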


pyspark.sql.streaming.DataStreamWriter.foreachBatch

DataStreamWriter.foreachBatch(func)

Sets the output of the streaming query to be processed using the provided function. This is supported only in the micro-batch execution mode (that is, when the trigger is not continuous). In every micro-batch, the provided function will be called with (i) the output rows as a DataFrame and (ii) the batch identifier. The batchId can be used to deduplicate and transactionally write the output (that is, the provided Dataset) to external systems. The output DataFrame is guaranteed to be exactly the same for the same batchId (assuming all operations are deterministic in the query).

New in version 2.4.0.

Notes

This API is evolving.

Examples

>>> def func(batch_df, batch_id):
...     batch_df.collect()
...
>>> writer = sdf.writeStream.foreachBatch(func)
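
In practice the batch_id is typically folded into the write itself so that re-running a batch is idempotent. A minimal sketch (the output path is hypothetical):

>>> def write_batch(batch_df, batch_id):
...     # Writing each batch under a batch_id-specific path makes retries overwrite
...     # their own output instead of appending duplicates.
...     batch_df.write.mode("overwrite").parquet("/tmp/stream_output/batch=%d" % batch_id)
...
>>> writer = sdf.writeStream.foreachBatch(write_batch)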

pyspark.sql.streaming.DataStreamWriter.format

DataStreamWriter.format(source)

Specifies the underlying output data source.

New in version 2.0.0.

Parameters

source [str] string, name of the output data source, e.g. 'parquet', 'json'.

Notes

This API is evolving.

Examples

>>> writer = sdf.writeStream.format('json')


pyspark.sql.streaming.DataStreamWriter.option

DataStreamWriter.option(key, value)

Adds an output option for the underlying data source.

You can set the following option(s) for writing files:

• timeZone: sets the string that indicates a time zone ID to be used to format timestamps in the JSON/CSV datasources or partition values. The following formats of timeZone are supported:

– Region-based zone ID: It should have the form 'area/city', such as 'America/Los_Angeles'.

– Zone offset: It should be in the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'.

Other short names like 'CST' are not recommended because they can be ambiguous. If it isn't set, the current value of the SQL config spark.sql.session.timeZone is used by default.

New in version 2.0.0.

Notes

This API is evolving.
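
Examples

A minimal sketch (the checkpoint path is hypothetical):

>>> writer = sdf.writeStream.option('timeZone', 'America/Los_Angeles')
>>> writer = sdf.writeStream.option('checkpointLocation', '/tmp/checkpoint')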

pyspark.sql.streaming.DataStreamWriter.options

DataStreamWriter.options(**options)

Adds output options for the underlying data source.

You can set the following option(s) for writing files:

• timeZone: sets the string that indicates a time zone ID to be used to format timestamps in the JSON/CSV datasources or partition values. The following formats of timeZone are supported:

– Region-based zone ID: It should have the form 'area/city', such as 'America/Los_Angeles'.

– Zone offset: It should be in the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'.

Other short names like 'CST' are not recommended because they can be ambiguous. If it isn't set, the current value of the SQL config spark.sql.session.timeZone is used by default.

New in version 2.0.0.

Notes

This API is evolving.
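
Examples

A minimal sketch, setting several options at once (the checkpoint path is hypothetical):

>>> writer = sdf.writeStream.options(
...     timeZone='America/Los_Angeles', checkpointLocation='/tmp/checkpoint')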


pyspark.sql.streaming.DataStreamWriter.outputMode

DataStreamWriter.outputMode(outputMode)

Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.

New in version 2.0.0.

Options include:

• append: Only the new rows in the streaming DataFrame/Dataset will be written to the sink

• complete: All the rows in the streaming DataFrame/Dataset will be written to the sink every time there are some updates

• update: only the rows that were updated in the streaming DataFrame/Dataset will be written to the sink every time there are some updates. If the query doesn't contain aggregations, it will be equivalent to append mode.

Notes

This API is evolving.

Examples

>>> writer = sdf.writeStream.outputMode('append')

pyspark.sql.streaming.DataStreamWriter.partitionBy

DataStreamWriter.partitionBy(*cols)

Partitions the output by the given columns on the file system.

If specified, the output is laid out on the file system similar to Hive's partitioning scheme.

New in version 2.0.0.

Parameters

cols [str or list] names of partitioning columns

Notes

This API is evolving.
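
Examples

A minimal sketch; the 'year' and 'month' columns are hypothetical columns of sdf:

>>> writer = sdf.writeStream.format('parquet').partitionBy('year', 'month')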

pyspark.sql.streaming.DataStreamWriter.queryName

DataStreamWriter.queryName(queryName)

Specifies the name of the StreamingQuery that can be started with start(). This name must be unique among all the currently active queries in the associated SparkSession.

New in version 2.0.0.

Parameters

queryName [str] unique name for the query


Notes

This API is evolving.

Examples

>>> writer = sdf.writeStream.queryName('streaming_query')

pyspark.sql.streaming.DataStreamWriter.start

DataStreamWriter.start(path=None, format=None, outputMode=None, partitionBy=None, queryName=None, **options)

Streams the contents of the DataFrame to a data source.

The data source is specified by the format and a set of options. If format is not specified, the default data source configured by spark.sql.sources.default will be used.

New in version 2.0.0.

Parameters

path [str, optional] the path in a Hadoop supported file system

format [str, optional] the format used to save

outputMode [str, optional] specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.

• append: Only the new rows in the streaming DataFrame/Dataset will be written to the sink

• complete: All the rows in the streaming DataFrame/Dataset will be written to the sink every time there are some updates

• update: only the rows that were updated in the streaming DataFrame/Dataset will be written to the sink every time there are some updates. If the query doesn't contain aggregations, it will be equivalent to append mode.

partitionBy [str or list, optional] names of partitioning columns

queryName [str, optional] unique name for the query

**options [dict] All other string options. You may want to provide a checkpointLocation for most streams; however, it is not required for a memory stream.

Notes

This API is evolving.


Examples

>>> sq = sdf.writeStream.format('memory').queryName('this_query').start()
>>> sq.isActive
True
>>> sq.name
'this_query'
>>> sq.stop()
>>> sq.isActive
False
>>> sq = sdf.writeStream.trigger(processingTime='5 seconds').start(
...     queryName='that_query', outputMode="append", format='memory')
>>> sq.name
'that_query'
>>> sq.isActive
True
>>> sq.stop()

pyspark.sql.streaming.DataStreamWriter.trigger

DataStreamWriter.trigger(*args, **kwargs)

Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to processingTime='0 seconds'.

New in version 2.0.0.

Parameters

processingTime [str, optional] a processing time interval as a string, e.g. '5 seconds', '1 minute'. Set a trigger that runs a microbatch query periodically based on the processing time. Only one trigger can be set.

once [bool, optional] if set to True, set a trigger that processes only one batch of data in a streaming query then terminates the query. Only one trigger can be set.

continuous [str, optional] a time interval as a string, e.g. '5 seconds', '1 minute'. Set a trigger that runs a continuous query with a given checkpoint interval. Only one trigger can be set.

Notes

This API is evolving.

Examples

>>> # trigger the query for execution every 5 seconds
>>> writer = sdf.writeStream.trigger(processingTime='5 seconds')
>>> # trigger the query for just one batch of data
>>> writer = sdf.writeStream.trigger(once=True)
>>> # trigger the query with a continuous checkpoint interval of 5 seconds
>>> writer = sdf.writeStream.trigger(continuous='5 seconds')


3.2.3 Query Management

StreamingQuery.awaitTermination([timeout])    Waits for the termination of this query, either by query.stop() or by an exception.
StreamingQuery.exception()    Returns the StreamingQueryException if the query was terminated by an exception, or None.
StreamingQuery.explain([extended])    Prints the (logical and physical) plans to the console for debugging purposes.
StreamingQuery.id    Returns the unique id of this query that persists across restarts from checkpoint data.
StreamingQuery.isActive    Whether this streaming query is currently active or not.
StreamingQuery.lastProgress    Returns the most recent StreamingQueryProgress update of this streaming query, or None if there were no progress updates.
StreamingQuery.name    Returns the user-specified name of the query, or null if not specified.
StreamingQuery.processAllAvailable()    Blocks until all available data in the source has been processed and committed to the sink.
StreamingQuery.recentProgress    Returns an array of the most recent StreamingQueryProgress updates for this query.
StreamingQuery.runId    Returns the unique id of this query that does not persist across restarts.
StreamingQuery.status    Returns the current status of the query.
StreamingQuery.stop()    Stop this streaming query.
StreamingQueryManager.active    Returns a list of active queries associated with this SQLContext.
StreamingQueryManager.awaitAnyTermination([timeout])    Wait until any of the queries on the associated SQLContext has terminated since the creation of the context, or since resetTerminated() was called.
StreamingQueryManager.get(id)    Returns an active query from this SQLContext or throws an exception if an active query with this name doesn't exist.
StreamingQueryManager.resetTerminated()    Forget about past terminated queries so that awaitAnyTermination() can be used again to wait for new terminations.

pyspark.sql.streaming.StreamingQuery.awaitTermination

StreamingQuery.awaitTermination(timeout=None)

Waits for the termination of this query, either by query.stop() or by an exception. If the query has terminated with an exception, then the exception will be thrown. If timeout is set, it returns whether the query has terminated or not within the timeout seconds.

If the query has terminated, then all subsequent calls to this method will either return immediately (if the query was terminated by stop()), or throw the exception immediately (if the query has terminated with exception).

throws StreamingQueryException, if this query has terminated with an exception

New in version 2.0.
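
Examples

A minimal sketch of the timeout behaviour (the query name is arbitrary):

>>> sq = sdf.writeStream.format('memory').queryName('await_example').start()
>>> sq.awaitTermination(timeout=1)  # still running, so this returns after ~1 second
False
>>> sq.stop()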


pyspark.sql.streaming.StreamingQuery.exception

StreamingQuery.exception()

New in version 2.1.0.

Returns

StreamingQueryException the StreamingQueryException if the query was terminated by an exception, or None.

pyspark.sql.streaming.StreamingQuery.explain

StreamingQuery.explain(extended=False)

Prints the (logical and physical) plans to the console for debugging purposes.

New in version 2.1.0.

Parameters

extended [bool, optional] default False. If False, prints only the physical plan.

Examples

>>> sq = sdf.writeStream.format('memory').queryName('query_explain').start()
>>> sq.processAllAvailable()  # Wait a bit to generate the runtime plans.
>>> sq.explain()
== Physical Plan ==
...
>>> sq.explain(True)
== Parsed Logical Plan ==
...
== Analyzed Logical Plan ==
...
== Optimized Logical Plan ==
...
== Physical Plan ==
...
>>> sq.stop()

pyspark.sql.streaming.StreamingQuery.id

property StreamingQuery.id

Returns the unique id of this query that persists across restarts from checkpoint data. That is, this id is generated when a query is started for the first time, and will be the same every time it is restarted from checkpoint data. There can only be one query with the same id active in a Spark cluster. See also runId.

New in version 2.0.


pyspark.sql.streaming.StreamingQuery.isActive

property StreamingQuery.isActive

Whether this streaming query is currently active or not.

New in version 2.0.

pyspark.sql.streaming.StreamingQuery.lastProgress

property StreamingQuery.lastProgress

Returns the most recent StreamingQueryProgress update of this streaming query, or None if there were no progress updates.

New in version 2.1.0.

Returns

dict

pyspark.sql.streaming.StreamingQuery.name

property StreamingQuery.name

Returns the user-specified name of the query, or null if not specified. This name can be specified in the org.apache.spark.sql.streaming.DataStreamWriter as dataframe.writeStream.queryName("query").start(). This name, if set, must be unique across all active queries.

New in version 2.0.

pyspark.sql.streaming.StreamingQuery.processAllAvailable

StreamingQuery.processAllAvailable()

Blocks until all available data in the source has been processed and committed to the sink. This method is intended for testing.

New in version 2.0.0.

Notes

In the case of continually arriving data, this method may block forever. Additionally, this method is only guaranteed to block until data that has been synchronously appended to a stream source prior to invocation has been processed (i.e. getOffset must immediately reflect the addition).

pyspark.sql.streaming.StreamingQuery.recentProgress

property StreamingQuery.recentProgress

Returns an array of the most recent StreamingQueryProgress updates for this query. The number of progress updates retained for each stream is configured by the Spark session configuration spark.sql.streaming.numRecentProgressUpdates.

New in version 2.1.


pyspark.sql.streaming.StreamingQuery.runId

property StreamingQuery.runId

Returns the unique id of this query that does not persist across restarts. That is, every query that is started (or restarted from checkpoint) will have a different runId.

New in version 2.1.

pyspark.sql.streaming.StreamingQuery.status

property StreamingQuery.status

Returns the current status of the query.

New in version 2.1.

pyspark.sql.streaming.StreamingQuery.stop

StreamingQuery.stop()

Stop this streaming query.

New in version 2.0.

pyspark.sql.streaming.StreamingQueryManager.active

property StreamingQueryManager.active

Returns a list of active queries associated with this SQLContext.

New in version 2.0.0.

Examples

>>> sq = sdf.writeStream.format('memory').queryName('this_query').start()
>>> sqm = spark.streams
>>> # get the list of active streaming queries
>>> [q.name for q in sqm.active]
['this_query']
>>> sq.stop()

pyspark.sql.streaming.StreamingQueryManager.awaitAnyTermination

StreamingQueryManager.awaitAnyTermination(timeout=None)

Wait until any of the queries on the associated SQLContext has terminated since the creation of the context, or since resetTerminated() was called. If any query was terminated with an exception, then the exception will be thrown. If timeout is set, it returns whether the query has terminated or not within the timeout seconds.

If a query has terminated, then subsequent calls to awaitAnyTermination() will either return immediately (if the query was terminated by query.stop()), or throw the exception immediately (if the query was terminated with exception). Use resetTerminated() to clear past terminations and wait for new terminations.

In the case where multiple queries have terminated since resetTerminated() was called, if any query has terminated with an exception, then awaitAnyTermination() will throw any of the exceptions. For correctly documenting exceptions across multiple queries, users need to stop all of them after any of them terminates with an exception, and then check the query.exception() for each query.

throws StreamingQueryException, if this query has terminated with an exception

New in version 2.0.
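
Examples

A minimal sketch (the query name is arbitrary):

>>> spark.streams.resetTerminated()
>>> sq = sdf.writeStream.format('memory').queryName('await_any_example').start()
>>> spark.streams.awaitAnyTermination(timeout=1)  # nothing has terminated yet
False
>>> sq.stop()
>>> spark.streams.awaitAnyTermination(timeout=1)  # sq has now terminated
True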

pyspark.sql.streaming.StreamingQueryManager.get

StreamingQueryManager.get(id)

Returns an active query from this SQLContext or throws an exception if an active query with this id doesn't exist.

New in version 2.0.0.

Examples

>>> sq = sdf.writeStream.format('memory').queryName('this_query').start()
>>> sq.name
'this_query'
>>> sq = spark.streams.get(sq.id)
>>> sq.isActive
True
>>> sq = sqlContext.streams.get(sq.id)
>>> sq.isActive
True
>>> sq.stop()

pyspark.sql.streaming.StreamingQueryManager.resetTerminated

StreamingQueryManager.resetTerminated()

Forget about past terminated queries so that awaitAnyTermination() can be used again to wait for new terminations.

New in version 2.0.0.

Examples

>>> spark.streams.resetTerminated()

3.3 ML

3.3.1 ML Pipeline APIs

Transformer()    Abstract class for transformers that transform one dataset into another.
UnaryTransformer()    Abstract class for transformers that take one input column, apply transformation, and output the result as a new column.
Estimator()    Abstract class for estimators that fit models to data.
Model()    Abstract class for models that are fitted by estimators.
Predictor()    Estimator for prediction tasks (regression and classification).
PredictionModel()    Model for prediction tasks (regression and classification).
Pipeline(*args, **kwargs)    A simple pipeline, which acts as an estimator.
PipelineModel(stages)    Represents a compiled pipeline with transformers and fitted models.

Transformer

class pyspark.ml.Transformer

Abstract class for transformers that transform one dataset into another.

New in version 1.3.0.

Methods

clear(param)    Clears a param from the param map if it has been explicitly set.
copy([extra])    Creates a copy of this instance with the same uid and some extra params.
explainParam(param)    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()    Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getOrDefault(param)    Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)    Gets a param by its name.
hasDefault(param)    Checks whether a param has a default value.
hasParam(paramName)    Tests whether this instance contains a param with a given (string) name.
isDefined(param)    Checks whether a param is explicitly set by user or has a default value.
isSet(param)    Checks whether a param is explicitly set by user.
set(param, value)    Sets a parameter in the embedded param map.
transform(dataset[, params])    Transforms the input dataset with optional parameters.


Attributes

params Returns all params ordered by name.

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

Params Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.


set(param, value)
Sets a parameter in the embedded param map.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

Attributes Documentation

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

UnaryTransformer

class pyspark.ml.UnaryTransformer

Abstract class for transformers that take one input column, apply transformation, and output the result as a new column.

New in version 2.3.0.
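
Examples

A minimal sketch of a concrete subclass (not part of the library) that squares a numeric input column. The DataFrame df and its column "x" are hypothetical.

>>> from pyspark.ml import UnaryTransformer
>>> from pyspark.sql.types import DoubleType, NumericType
>>> class SquareTransformer(UnaryTransformer):
...     def createTransformFunc(self):
...         # Applied to each value of the input column
...         return lambda x: float(x) ** 2
...     def outputDataType(self):
...         return DoubleType()
...     def validateInputType(self, inputType):
...         if not isinstance(inputType, NumericType):
...             raise TypeError("Expected a numeric input column, got %s" % inputType)
...
>>> squarer = SquareTransformer().setInputCol("x").setOutputCol("x_squared")
>>> squared_df = squarer.transform(df)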

Methods

clear(param)    Clears a param from the param map if it has been explicitly set.
copy([extra])    Creates a copy of this instance with the same uid and some extra params.
createTransformFunc()    Creates the transform function using the given param map.
explainParam(param)    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()    Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol()    Gets the value of inputCol or its default value.
getOrDefault(param)    Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()    Gets the value of outputCol or its default value.
getParam(paramName)    Gets a param by its name.
hasDefault(param)    Checks whether a param has a default value.
hasParam(paramName)    Tests whether this instance contains a param with a given (string) name.
isDefined(param)    Checks whether a param is explicitly set by user or has a default value.
isSet(param)    Checks whether a param is explicitly set by user.
outputDataType()    Returns the data type of the output column.
set(param, value)    Sets a parameter in the embedded param map.
setInputCol(value)    Sets the value of inputCol.
setOutputCol(value)    Sets the value of outputCol.
transform(dataset[, params])    Transforms the input dataset with optional parameters.
transformSchema(schema)
validateInputType(inputType)    Validates the input type.

Attributes

inputCol
outputCol
params    Returns all params ordered by name.

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

Params Copy of this instance

abstract createTransformFunc()
Creates the transform function using the given param map. The input param map already takes account of the embedded param map. So the param values should be determined solely by the input param map.

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCol()
Gets the value of inputCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

abstract outputDataType()
Returns the data type of the output column.

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.

setOutputCol(value)
Sets the value of outputCol.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

transformSchema(schema)


abstract validateInputType(inputType)
Validates the input type. Throw an exception if it is invalid.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

Estimator

class pyspark.ml.Estimator

Abstract class for estimators that fit models to data.

New in version 1.3.0.

Methods

clear(param)    Clears a param from the param map if it has been explicitly set.
copy([extra])    Creates a copy of this instance with the same uid and some extra params.
explainParam(param)    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()    Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])    Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)    Fits a model to the input dataset for each param map in paramMaps.
getOrDefault(param)    Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)    Gets a param by its name.
hasDefault(param)    Checks whether a param has a default value.
hasParam(paramName)    Tests whether this instance contains a param with a given (string) name.
isDefined(param)    Checks whether a param is explicitly set by user or has a default value.
isSet(param)    Checks whether a param is explicitly set by user.
set(param, value)    Sets a parameter in the embedded param map.


Attributes

params Returns all params ordered by name.

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

Params Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.


Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
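
For illustration (a sketch, not from the upstream docs): LogisticRegression is an Estimator, and train_df is a hypothetical DataFrame with "features" and "label" columns.

>>> from pyspark.ml.classification import LogisticRegression
>>> lr = LogisticRegression()
>>> model = lr.fit(train_df)                             # single fit; returns a Model
>>> param_maps = [{lr.maxIter: 5}, {lr.maxIter: 10}]
>>> models = dict(lr.fitMultiple(train_df, param_maps))  # {index: fitted model}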

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

set(param, value)
Sets a parameter in the embedded param map.

Attributes Documentation

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

Model

class pyspark.ml.Model

Abstract class for models that are fitted by estimators.

New in version 1.4.0.

Methods

clear(param)    Clears a param from the param map if it has been explicitly set.
copy([extra])    Creates a copy of this instance with the same uid and some extra params.
explainParam(param)    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()    Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getOrDefault(param)    Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)    Gets a param by its name.
hasDefault(param)    Checks whether a param has a default value.
hasParam(paramName)    Tests whether this instance contains a param with a given (string) name.
isDefined(param)    Checks whether a param is explicitly set by user or has a default value.
isSet(param)    Checks whether a param is explicitly set by user.
set(param, value)    Sets a parameter in the embedded param map.
transform(dataset[, params])    Transforms the input dataset with optional parameters.

Attributes

params Returns all params ordered by name.

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

Params Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.


Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

set(param, value)
Sets a parameter in the embedded param map.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

Attributes Documentation

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

Predictor

class pyspark.ml.Predictor

Estimator for prediction tasks (regression and classification).


Methods

clear(param)    Clears a param from the param map if it has been explicitly set.
copy([extra])    Creates a copy of this instance with the same uid and some extra params.
explainParam(param)    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()    Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])    Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)    Fits a model to the input dataset for each param map in paramMaps.
getFeaturesCol()    Gets the value of featuresCol or its default value.
getLabelCol()    Gets the value of labelCol or its default value.
getOrDefault(param)    Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)    Gets a param by its name.
getPredictionCol()    Gets the value of predictionCol or its default value.
hasDefault(param)    Checks whether a param has a default value.
hasParam(paramName)    Tests whether this instance contains a param with a given (string) name.
isDefined(param)    Checks whether a param is explicitly set by user or has a default value.
isSet(param)    Checks whether a param is explicitly set by user.
set(param, value)    Sets a parameter in the embedded param map.
setFeaturesCol(value)    Sets the value of featuresCol.
setLabelCol(value)    Sets the value of labelCol.
setPredictionCol(value)    Sets the value of predictionCol.

Attributes

featuresCol
labelCol
params    Returns all params ordered by name.
predictionCol


Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

Params Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.


Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setLabelCol(value)
Sets the value of labelCol.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.


Attributes Documentation

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

PredictionModel

class pyspark.ml.PredictionModel

Model for prediction tasks (regression and classification).

Methods

clear(param)    Clears a param from the param map if it has been explicitly set.
copy([extra])    Creates a copy of this instance with the same uid and some extra params.
explainParam(param)    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()    Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFeaturesCol()    Gets the value of featuresCol or its default value.
getLabelCol()    Gets the value of labelCol or its default value.
getOrDefault(param)    Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)    Gets a param by its name.
getPredictionCol()    Gets the value of predictionCol or its default value.
hasDefault(param)    Checks whether a param has a default value.
hasParam(paramName)    Tests whether this instance contains a param with a given (string) name.
isDefined(param)    Checks whether a param is explicitly set by user or has a default value.
isSet(param)    Checks whether a param is explicitly set by user.
predict(value)    Predict label for the given features.
set(param, value)    Sets a parameter in the embedded param map.
setFeaturesCol(value)    Sets the value of featuresCol.
setPredictionCol(value)    Sets the value of predictionCol.
transform(dataset[, params])    Transforms the input dataset with optional parameters.


Attributes

featuresCol
labelCol
numFeatures    Returns the number of features the model was trained on.
params    Returns all params ordered by name.
predictionCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

Params Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getFeaturesCol()
Gets the value of featuresCol or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.


getPredictionCol()
Gets the value of predictionCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

abstract predict(value)
Predict label for the given features.

New in version 3.0.0.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

Attributes Documentation

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.

New in version 2.1.0.

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')


Pipeline

class pyspark.ml.Pipeline(*args, **kwargs)

A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each of which is either an Estimator or a Transformer. When Pipeline.fit() is called, the stages are executed in order. If a stage is an Estimator, its Estimator.fit() method will be called on the input dataset to fit a model. Then the model, which is a transformer, will be used to transform the dataset as the input to the next stage. If a stage is a Transformer, its Transformer.transform() method will be called to produce the dataset for the next stage. The fitted model from a Pipeline is a PipelineModel, which consists of fitted models and transformers, corresponding to the pipeline stages. If stages is an empty list, the pipeline acts as an identity transformer.

New in version 1.3.0.
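
Examples

A minimal sketch of assembling and fitting a pipeline; training_df and test_df are hypothetical DataFrames with "text" and "label" columns.

>>> from pyspark.ml import Pipeline
>>> from pyspark.ml.feature import Tokenizer, HashingTF
>>> from pyspark.ml.classification import LogisticRegression
>>> tokenizer = Tokenizer(inputCol="text", outputCol="words")
>>> hashing_tf = HashingTF(inputCol="words", outputCol="features")
>>> lr = LogisticRegression(maxIter=10)
>>> pipeline = Pipeline(stages=[tokenizer, hashing_tf, lr])
>>> model = pipeline.fit(training_df)       # returns a PipelineModel
>>> predictions = model.transform(test_df)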

Methods

clear(param)    Clears a param from the param map if it has been explicitly set.
copy([extra])    Creates a copy of this instance.
explainParam(param)    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()    Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])    Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)    Fits a model to the input dataset for each param map in paramMaps.
getOrDefault(param)    Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)    Gets a param by its name.
getStages()    Get pipeline stages.
hasDefault(param)    Checks whether a param has a default value.
hasParam(paramName)    Tests whether this instance contains a param with a given (string) name.
isDefined(param)    Checks whether a param is explicitly set by user or has a default value.
isSet(param)    Checks whether a param is explicitly set by user.
load(path)    Reads an ML instance from the input path, a shortcut of read().load(path).
read()    Returns an MLReader instance for this class.
save(path)    Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)    Sets a parameter in the embedded param map.
setParams(self, \*[, stages])    Sets params for Pipeline.
setStages(value)    Set pipeline stages.
write()    Returns an MLWriter instance for this ML instance.

Attributes

params    Returns all params ordered by name.
stages

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance.

New in version 1.4.0.

Parameters

extra [dict, optional] extra parameters

Returns

Pipeline new instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns


Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getStages()
Get pipeline stages.

New in version 1.3.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

New in version 2.0.0.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setParams(self, \*, stages=None)
Sets params for Pipeline.

New in version 1.3.0.

setStages(value)
Set pipeline stages.

New in version 1.3.0.


Parameters

value [list] of pyspark.ml.Transformer or pyspark.ml.Estimator

Returns

Pipeline the pipeline instance

write()
Returns an MLWriter instance for this ML instance.

New in version 2.0.0.

Attributes Documentation

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

stages = Param(parent='undefined', name='stages', doc='a list of pipeline stages')

PipelineModel

class pyspark.ml.PipelineModel(stages)

Represents a compiled pipeline with transformers and fitted models.

New in version 1.3.0.

Methods

clear(param)    Clears a param from the param map if it has been explicitly set.
copy([extra])    Creates a copy of this instance.
explainParam(param)    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()    Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getOrDefault(param)    Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)    Gets a param by its name.
hasDefault(param)    Checks whether a param has a default value.
hasParam(paramName)    Tests whether this instance contains a param with a given (string) name.
isDefined(param)    Checks whether a param is explicitly set by user or has a default value.
isSet(param)    Checks whether a param is explicitly set by user.
load(path)    Reads an ML instance from the input path, a shortcut of read().load(path).
read()    Returns an MLReader instance for this class.
save(path)    Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)    Sets a parameter in the embedded param map.
transform(dataset[, params])    Transforms the input dataset with optional parameters.
write()    Returns an MLWriter instance for this ML instance.

Attributes

params Returns all params ordered by name.

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance.

New in version 1.4.0.

Parameters extra – extra parameters

Returns new instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.


hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

New in version 2.0.0.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset
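As an illustration of the optional params override, a short hedged sketch, reusing the hypothetical spark, model, and hashingTF names from the Pipeline example earlier in this section:

# Override an embedded param of one stage for this call only; the fitted
# model itself is left unchanged.
test = spark.createDataFrame([("spark b d",), ("x y z",)], ["text"])
predictions = model.transform(test, {hashingTF.numFeatures: 1024})
predictions.select("text", "prediction").show()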

write()
Returns an MLWriter instance for this ML instance.

New in version 2.0.0.

Attributes Documentation

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

3.3.2 Parameters

Param(parent, name, doc[, typeConverter]) A param with self-contained documentation.
Params() Components that take parameters.
TypeConverters() Factory methods for common type conversion functions for Param.typeConverter.


Param

class pyspark.ml.param.Param(parent, name, doc, typeConverter=None)
A param with self-contained documentation.

New in version 1.3.0.

Params

class pyspark.ml.param.Params
Components that take parameters. This also provides an internal param map to store parameter values attached to the instance.

New in version 1.3.0.
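A brief sketch of the param-inspection methods below, using Binarizer (documented later in this section) purely as a convenient Params instance; the column names are made up for illustration.

from pyspark.ml.feature import Binarizer

b = Binarizer(threshold=0.5, inputCol="values", outputCol="features")
print(b.isSet(b.threshold))       # True  -- explicitly set above
print(b.isSet(b.inputCols))       # False -- never touched
print(b.hasDefault(b.threshold))  # True  -- threshold defaults to 0.0
print(b.explainParam("threshold"))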

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
set(param, value) Sets a parameter in the embedded param map.


Attributes

params Returns all params ordered by name.

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

Params Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
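A small sketch of the merge ordering described above, again using Binarizer only as a convenient Params instance:

from pyspark.ml.feature import Binarizer

b = Binarizer(inputCol="values", outputCol="features")  # threshold keeps its default, 0.0
b.setThreshold(0.5)                                      # user-supplied value beats the default

# extra values beat both defaults and user-supplied values
pm = b.extractParamMap({b.threshold: 2.0})
print(pm[b.threshold])   # 2.0
print(b.getThreshold())  # 0.5 -- the instance itself is unchanged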

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.


set(param, value)
Sets a parameter in the embedded param map.

Attributes Documentation

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

TypeConverters

class pyspark.ml.param.TypeConverters
Factory methods for common type conversion functions for Param.typeConverter.

New in version 2.0.0.

Methods

identity(value) Dummy converter that just returns value.
toBoolean(value) Convert a value to a boolean, if possible.
toFloat(value) Convert a value to a float, if possible.
toInt(value) Convert a value to an int, if possible.
toList(value) Convert a value to a list, if possible.
toListFloat(value) Convert a value to list of floats, if possible.
toListInt(value) Convert a value to list of ints, if possible.
toListListFloat(value) Convert a value to list of list of floats, if possible.
toListString(value) Convert a value to list of strings, if possible.
toMatrix(value) Convert a value to a MLlib Matrix, if possible.
toString(value) Convert a value to a string, if possible.
toVector(value) Convert a value to a MLlib Vector, if possible.

Methods Documentation

static identity(value)
Dummy converter that just returns value.

static toBoolean(value)
Convert a value to a boolean, if possible.

static toFloat(value)
Convert a value to a float, if possible.

static toInt(value)
Convert a value to an int, if possible.

static toList(value)
Convert a value to a list, if possible.

static toListFloat(value)
Convert a value to list of floats, if possible.

static toListInt(value)
Convert a value to list of ints, if possible.


static toListListFloat(value)
Convert a value to list of list of floats, if possible.

static toListString(value)
Convert a value to list of strings, if possible.

static toMatrix(value)
Convert a value to a MLlib Matrix, if possible.

static toString(value)
Convert a value to a string, if possible.

static toVector(value)
Convert a value to a MLlib Vector, if possible.
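A brief sketch of a few of these converters; they are the same functions a Param's typeConverter applies when a value is set, and the inputs below are just illustrative values.

from pyspark.ml.param import TypeConverters

TypeConverters.toInt(3.0)            # -> 3
TypeConverters.toListFloat([1, 2])   # -> [1.0, 2.0]
TypeConverters.toBoolean(True)       # -> True
TypeConverters.toVector([1.0, 2.0])  # -> a pyspark.ml.linalg.DenseVector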

3.3.3 Feature

ANOVASelector(*args, **kwargs) ANOVA F-value Classification selector, which selects continuous features to use for predicting a categorical label.
ANOVASelectorModel([java_model]) Model fitted by ANOVASelector.
Binarizer(*args, **kwargs) Binarize a column of continuous features given a threshold.
BucketedRandomProjectionLSH(*args, **kwargs) LSH class for Euclidean distance metrics.
BucketedRandomProjectionLSHModel([java_model]) Model fitted by BucketedRandomProjectionLSH, where multiple random vectors are stored.
Bucketizer(*args, **kwargs) Maps a column of continuous features to a column of feature buckets.
ChiSqSelector(*args, **kwargs) Chi-Squared feature selection, which selects categorical features to use for predicting a categorical label.
ChiSqSelectorModel([java_model]) Model fitted by ChiSqSelector.
CountVectorizer(*args, **kwargs) Extracts a vocabulary from document collections and generates a CountVectorizerModel.
CountVectorizerModel([java_model]) Model fitted by CountVectorizer.
DCT(*args, **kwargs) A feature transformer that takes the 1D discrete cosine transform of a real vector.
ElementwiseProduct(*args, **kwargs) Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a provided “weight” vector.
FeatureHasher(*args, **kwargs) Feature hashing projects a set of categorical or numerical features into a feature vector of specified dimension (typically substantially smaller than that of the original feature space).
FValueSelector(*args, **kwargs) F Value Regression feature selector, which selects continuous features to use for predicting a continuous label.
FValueSelectorModel([java_model]) Model fitted by FValueSelector.
HashingTF(*args, **kwargs) Maps a sequence of terms to their term frequencies using the hashing trick.
IDF(*args, **kwargs) Compute the Inverse Document Frequency (IDF) given a collection of documents.
IDFModel([java_model]) Model fitted by IDF.
Imputer(*args, **kwargs) Imputation estimator for completing missing values, either using the mean or the median of the columns in which the missing values are located.
ImputerModel([java_model]) Model fitted by Imputer.
IndexToString(*args, **kwargs) A pyspark.ml.base.Transformer that maps a column of indices back to a new column of corresponding string values.
Interaction(*args, **kwargs) Implements the feature interaction transform.
MaxAbsScaler(*args, **kwargs) Rescale each feature individually to range [-1, 1] by dividing through the largest maximum absolute value in each feature.
MaxAbsScalerModel([java_model]) Model fitted by MaxAbsScaler.
MinHashLSH(*args, **kwargs) LSH class for Jaccard distance.
MinHashLSHModel([java_model]) Model produced by MinHashLSH, where multiple hash functions are stored.
MinMaxScaler(*args, **kwargs) Rescale each feature individually to a common range [min, max] linearly using column summary statistics, which is also known as min-max normalization or Rescaling.
MinMaxScalerModel([java_model]) Model fitted by MinMaxScaler.
NGram(*args, **kwargs) A feature transformer that converts the input array of strings into an array of n-grams.
Normalizer(*args, **kwargs) Normalize a vector to have unit norm using the given p-norm.
OneHotEncoder(*args, **kwargs) A one-hot encoder that maps a column of category indices to a column of binary vectors, with at most a single one-value per row that indicates the input category index.
OneHotEncoderModel([java_model]) Model fitted by OneHotEncoder.
PCA(*args, **kwargs) PCA trains a model to project vectors to a lower dimensional space of the top k principal components.
PCAModel([java_model]) Model fitted by PCA.
PolynomialExpansion(*args, **kwargs) Perform feature expansion in a polynomial space.
QuantileDiscretizer(*args, **kwargs) QuantileDiscretizer takes a column with continuous features and outputs a column with binned categorical features.
RobustScaler(*args, **kwargs) RobustScaler removes the median and scales the data according to the quantile range.
RobustScalerModel([java_model]) Model fitted by RobustScaler.
RegexTokenizer(*args, **kwargs) A regex based tokenizer that extracts tokens either by using the provided regex pattern (in Java dialect) to split the text (default) or repeatedly matching the regex (if gaps is false).
RFormula(*args, **kwargs) Implements the transforms required for fitting a dataset against an R model formula.
RFormulaModel([java_model]) Model fitted by RFormula.
SQLTransformer(*args, **kwargs) Implements the transforms which are defined by SQL statement.
StandardScaler(*args, **kwargs) Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set.
StandardScalerModel([java_model]) Model fitted by StandardScaler.
StopWordsRemover(*args, **kwargs) A feature transformer that filters out stop words from input.
StringIndexer(*args, **kwargs) A label indexer that maps a string column of labels to an ML column of label indices.
StringIndexerModel([java_model]) Model fitted by StringIndexer.
Tokenizer(*args, **kwargs) A tokenizer that converts the input string to lowercase and then splits it by white spaces.
VarianceThresholdSelector(*args, **kwargs) Feature selector that removes all low-variance features.
VarianceThresholdSelectorModel([java_model]) Model fitted by VarianceThresholdSelector.
VectorAssembler(*args, **kwargs) A feature transformer that merges multiple columns into a vector column.
VectorIndexer(*args, **kwargs) Class for indexing categorical feature columns in a dataset of Vector.
VectorIndexerModel([java_model]) Model fitted by VectorIndexer.
VectorSizeHint(*args, **kwargs) A feature transformer that adds size information to the metadata of a vector column.
VectorSlicer(*args, **kwargs) This class takes a feature vector and outputs a new feature vector with a subarray of the original features.
Word2Vec(*args, **kwargs) Word2Vec trains a model of Map(String, Vector), i.e.
Word2VecModel([java_model]) Model fitted by Word2Vec.

ANOVASelector

class pyspark.ml.feature.ANOVASelector(*args, **kwargs)
ANOVA F-value Classification selector, which selects continuous features to use for predicting a categorical label. The selector supports different selection methods: numTopFeatures, percentile, fpr, fdr, fwe.

• numTopFeatures chooses a fixed number of top features according to an F value classification test.

• percentile is similar but chooses a fraction of all features instead of a fixed number.

• fpr chooses all features whose p-values are below a threshold, thus controlling the false positive rate of selection.

• fdr uses the Benjamini-Hochberg procedure to choose all features whose false discovery rate is below a threshold.

• fwe chooses all features whose p-values are below a threshold. The threshold is scaled by 1 / numFeatures, thus controlling the family-wise error rate of selection.

By default, the selection method is numTopFeatures, with the default number of top features set to 50.

New in version 3.1.0.


Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame(
...    [(Vectors.dense([1.7, 4.4, 7.6, 5.8, 9.6, 2.3]), 3.0),
...     (Vectors.dense([8.8, 7.3, 5.7, 7.3, 2.2, 4.1]), 2.0),
...     (Vectors.dense([1.2, 9.5, 2.5, 3.1, 8.7, 2.5]), 1.0),
...     (Vectors.dense([3.7, 9.2, 6.1, 4.1, 7.5, 3.8]), 2.0),
...     (Vectors.dense([8.9, 5.2, 7.8, 8.3, 5.2, 3.0]), 4.0),
...     (Vectors.dense([7.9, 8.5, 9.2, 4.0, 9.4, 2.1]), 4.0)],
...    ["features", "label"])
>>> selector = ANOVASelector(numTopFeatures=1, outputCol="selectedFeatures")
>>> model = selector.fit(df)
>>> model.getFeaturesCol()
'features'
>>> model.setFeaturesCol("features")
ANOVASelectorModel...
>>> model.transform(df).head().selectedFeatures
DenseVector([7.6])
>>> model.selectedFeatures
[2]
>>> anovaSelectorPath = temp_path + "/anova-selector"
>>> selector.save(anovaSelectorPath)
>>> loadedSelector = ANOVASelector.load(anovaSelectorPath)
>>> loadedSelector.getNumTopFeatures() == selector.getNumTopFeatures()
True
>>> modelPath = temp_path + "/anova-selector-model"
>>> model.save(modelPath)
>>> loadedModel = ANOVASelectorModel.load(modelPath)
>>> loadedModel.selectedFeatures == model.selectedFeatures
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.


fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getFdr() Gets the value of fdr or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFpr() Gets the value of fpr or its default value.
getFwe() Gets the value of fwe or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getNumTopFeatures() Gets the value of numTopFeatures or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
getPercentile() Gets the value of percentile or its default value.
getSelectorType() Gets the value of selectorType or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setFdr(value) Sets the value of fdr.
setFeaturesCol(value) Sets the value of featuresCol.
setFpr(value) Sets the value of fpr.
setFwe(value) Sets the value of fwe.
setLabelCol(value) Sets the value of labelCol.
setNumTopFeatures(value) Sets the value of numTopFeatures.
setOutputCol(value) Sets the value of outputCol.
setParams(self, *[, numTopFeatures, ...]) Sets params for this ANOVASelector.
setPercentile(value) Sets the value of percentile.
setSelectorType(value) Sets the value of selectorType.
write() Returns an MLWriter instance for this ML instance.

Attributes

fdr
featuresCol
fpr
fwe
labelCol
numTopFeatures
outputCol
params Returns all params ordered by name.


percentile
selectorType

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)
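For instance, a hedged sketch of the list form, reusing the selector and df names from the Examples above: two param maps yield two fitted models.

models = selector.fit(df, [
    {selector.numTopFeatures: 1},
    {selector.numTopFeatures: 2},
])
print([m.selectedFeatures for m in models])  # one index list per fitted model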

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.


Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
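A hedged sketch of consuming the iterator, again reusing selector and df from the Examples above; ParamGridBuilder is just one convenient way to build the sequence of param maps.

from pyspark.ml.tuning import ParamGridBuilder

paramMaps = (ParamGridBuilder()
             .addGrid(selector.numTopFeatures, [1, 2, 3])
             .build())
models = [None] * len(paramMaps)
for index, model in selector.fitMultiple(df, paramMaps):
    models[index] = model   # index values may arrive out of order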

getFdr()
Gets the value of fdr or its default value.

New in version 2.2.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFpr()
Gets the value of fpr or its default value.

New in version 2.1.0.

getFwe()
Gets the value of fwe or its default value.

New in version 2.2.0.

getLabelCol()
Gets the value of labelCol or its default value.

getNumTopFeatures()
Gets the value of numTopFeatures or its default value.

New in version 2.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getPercentile()
Gets the value of percentile or its default value.

New in version 2.1.0.

getSelectorType()
Gets the value of selectorType or its default value.

New in version 2.1.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.


isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setFdr(value)
Sets the value of fdr. Only applicable when selectorType = “fdr”.

New in version 2.2.0.

setFeaturesCol(value)
Sets the value of featuresCol.

setFpr(value)
Sets the value of fpr. Only applicable when selectorType = “fpr”.

New in version 2.1.0.

setFwe(value)
Sets the value of fwe. Only applicable when selectorType = “fwe”.

New in version 2.2.0.

setLabelCol(value)
Sets the value of labelCol.

setNumTopFeatures(value)
Sets the value of numTopFeatures. Only applicable when selectorType = “numTopFeatures”.

New in version 2.0.0.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, *, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, fdr=0.05, fwe=0.05)
Sets params for this ANOVASelector.

New in version 3.1.0.

setPercentile(value)
Sets the value of percentile. Only applicable when selectorType = “percentile”.

New in version 2.1.0.

setSelectorType(value)
Sets the value of selectorType.

New in version 2.1.0.

write()
Returns an MLWriter instance for this ML instance.


Attributes Documentation

fdr = Param(parent='undefined', name='fdr', doc='The upper bound of the expected false discovery rate.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fpr = Param(parent='undefined', name='fpr', doc='The highest p-value for features to be kept.')

fwe = Param(parent='undefined', name='fwe', doc='The upper bound of the expected family-wise error rate.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

numTopFeatures = Param(parent='undefined', name='numTopFeatures', doc='Number of features that selector will select, ordered by ascending p-value. If the number of features is < numTopFeatures, then this will select all features.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

percentile = Param(parent='undefined', name='percentile', doc='Percentile of features that selector will select, ordered by ascending p-value.')

selectorType = Param(parent='undefined', name='selectorType', doc='The selector type. Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.')

ANOVASelectorModel

class pyspark.ml.feature.ANOVASelectorModel(java_model=None)
Model fitted by ANOVASelector.

New in version 3.1.0.

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFdr() Gets the value of fdr or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFpr() Gets the value of fpr or its default value.
getFwe() Gets the value of fwe or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getNumTopFeatures() Gets the value of numTopFeatures or its default value.


getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
getPercentile() Gets the value of percentile or its default value.
getSelectorType() Gets the value of selectorType or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setOutputCol(value) Sets the value of outputCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

fdr
featuresCol
fpr
fwe
labelCol
numTopFeatures
outputCol
params Returns all params ordered by name.
percentile
selectedFeatures List of indices to select (filter).
selectorType

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance


Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getFdr()
Gets the value of fdr or its default value.

New in version 2.2.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFpr()
Gets the value of fpr or its default value.

New in version 2.1.0.

getFwe()
Gets the value of fwe or its default value.

New in version 2.2.0.

getLabelCol()
Gets the value of labelCol or its default value.

getNumTopFeatures()
Gets the value of numTopFeatures or its default value.

New in version 2.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getPercentile()
Gets the value of percentile or its default value.

New in version 2.1.0.


getSelectorType()
Gets the value of selectorType or its default value.

New in version 2.1.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setOutputCol(value)
Sets the value of outputCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.


Attributes Documentation

fdr = Param(parent='undefined', name='fdr', doc='The upper bound of the expected false discovery rate.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fpr = Param(parent='undefined', name='fpr', doc='The highest p-value for features to be kept.')

fwe = Param(parent='undefined', name='fwe', doc='The upper bound of the expected family-wise error rate.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

numTopFeatures = Param(parent='undefined', name='numTopFeatures', doc='Number of features that selector will select, ordered by ascending p-value. If the number of features is < numTopFeatures, then this will select all features.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

percentile = Param(parent='undefined', name='percentile', doc='Percentile of features that selector will select, ordered by ascending p-value.')

selectedFeatures
List of indices to select (filter).

New in version 2.0.0.

selectorType = Param(parent='undefined', name='selectorType', doc='The selector type. Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.')

Binarizer

class pyspark.ml.feature.Binarizer(*args, **kwargs)
Binarize a column of continuous features given a threshold. Since 3.0.0, Binarizer can map multiple columns at once by setting the inputCols parameter. Note that when both the inputCol and inputCols parameters are set, an Exception will be thrown. The threshold parameter is used for single column usage, and thresholds is for multiple columns.

New in version 1.4.0.

Examples

>>> df = spark.createDataFrame([(0.5,)], ["values"])
>>> binarizer = Binarizer(threshold=1.0, inputCol="values", outputCol="features")
>>> binarizer.setThreshold(1.0)
Binarizer...
>>> binarizer.setInputCol("values")
Binarizer...
>>> binarizer.setOutputCol("features")
Binarizer...
>>> binarizer.transform(df).head().features
0.0
>>> binarizer.setParams(outputCol="freqs").transform(df).head().freqs
0.0
>>> params = {binarizer.threshold: -0.5, binarizer.outputCol: "vector"}
>>> binarizer.transform(df, params).head().vector
1.0
>>> binarizerPath = temp_path + "/binarizer"
>>> binarizer.save(binarizerPath)
>>> loadedBinarizer = Binarizer.load(binarizerPath)
>>> loadedBinarizer.getThreshold() == binarizer.getThreshold()
True
>>> loadedBinarizer.transform(df).take(1) == binarizer.transform(df).take(1)
True
>>> df2 = spark.createDataFrame([(0.5, 0.3)], ["values1", "values2"])
>>> binarizer2 = Binarizer(thresholds=[0.0, 1.0])
>>> binarizer2.setInputCols(["values1", "values2"]).setOutputCols(["output1", "output2"])
Binarizer...
>>> binarizer2.transform(df2).show()
+-------+-------+-------+-------+
|values1|values2|output1|output2|
+-------+-------+-------+-------+
|    0.5|    0.3|    1.0|    0.0|
+-------+-------+-------+-------+
...

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol() Gets the value of inputCol or its default value.
getInputCols() Gets the value of inputCols or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getOutputCols() Gets the value of outputCols or its default value.
getParam(paramName) Gets a param by its name.
getThreshold() Gets the value of threshold or its default value.
getThresholds() Gets the value of thresholds or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.


load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setInputCols(value) Sets the value of inputCols.
setOutputCol(value) Sets the value of outputCol.
setOutputCols(value) Sets the value of outputCols.
setParams(self, *[, threshold, inputCol, ...]) Sets params for this Binarizer.
setThreshold(value) Sets the value of threshold.
setThresholds(value) Sets the value of thresholds.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

inputCol
inputCols
outputCol
outputCols
params Returns all params ordered by name.
threshold
thresholds

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.


extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCol()
Gets the value of inputCol or its default value.

getInputCols()
Gets the value of inputCols or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getOutputCols()
Gets the value of outputCols or its default value.

getParam(paramName)
Gets a param by its name.

getThreshold()
Gets the value of threshold or its default value.

getThresholds()
Gets the value of thresholds or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.


setInputCols(value)
Sets the value of inputCols.

New in version 3.0.0.

setOutputCol(value)
Sets the value of outputCol.

setOutputCols(value)
Sets the value of outputCols.

New in version 3.0.0.

setParams(self, *, threshold=0.0, inputCol=None, outputCol=None, thresholds=None, inputCols=None, outputCols=None)
Sets params for this Binarizer.

New in version 1.4.0.

setThreshold(value)
Sets the value of threshold.

New in version 1.4.0.

setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

threshold = Param(parent='undefined', name='threshold', doc='Param for threshold used to binarize continuous features. The features greater than the threshold will be binarized to 1.0. The features equal to or less than the threshold will be binarized to 0.0')

thresholds = Param(parent='undefined', name='thresholds', doc='Param for array of threshold used to binarize continuous features. This is for multiple columns input. If transforming multiple columns and thresholds is not set, but threshold is set, then threshold will be applied across all columns.')


BucketedRandomProjectionLSH

class pyspark.ml.feature.BucketedRandomProjectionLSH(*args, **kwargs)
LSH class for Euclidean distance metrics. The input is dense or sparse vectors, each of which represents a point in the Euclidean distance space. The output will be vectors of configurable dimension. Hash values in the same dimension are calculated by the same hash function.

New in version 2.2.0.

Notes

• Stable Distributions in Wikipedia article on Locality-sensitive hashing

• Hashing for Similarity Search: A Survey

Examples

>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.sql.functions import col
>>> data = [(0, Vectors.dense([-1.0, -1.0 ]),),
...         (1, Vectors.dense([-1.0, 1.0 ]),),
...         (2, Vectors.dense([1.0, -1.0 ]),),
...         (3, Vectors.dense([1.0, 1.0]),)]
>>> df = spark.createDataFrame(data, ["id", "features"])
>>> brp = BucketedRandomProjectionLSH()
>>> brp.setInputCol("features")
BucketedRandomProjectionLSH...
>>> brp.setOutputCol("hashes")
BucketedRandomProjectionLSH...
>>> brp.setSeed(12345)
BucketedRandomProjectionLSH...
>>> brp.setBucketLength(1.0)
BucketedRandomProjectionLSH...
>>> model = brp.fit(df)
>>> model.getBucketLength()
1.0
>>> model.setOutputCol("hashes")
BucketedRandomProjectionLSHModel...
>>> model.transform(df).head()
Row(id=0, features=DenseVector([-1.0, -1.0]), hashes=[DenseVector([-1.0])])
>>> data2 = [(4, Vectors.dense([2.0, 2.0 ]),),
...          (5, Vectors.dense([2.0, 3.0 ]),),
...          (6, Vectors.dense([3.0, 2.0 ]),),
...          (7, Vectors.dense([3.0, 3.0]),)]
>>> df2 = spark.createDataFrame(data2, ["id", "features"])
>>> model.approxNearestNeighbors(df2, Vectors.dense([1.0, 2.0]), 1).collect()
[Row(id=4, features=DenseVector([2.0, 2.0]), hashes=[DenseVector([1.0])], distCol=1.0)]
>>> model.approxSimilarityJoin(df, df2, 3.0, distCol="EuclideanDistance").select(
...     col("datasetA.id").alias("idA"),
...     col("datasetB.id").alias("idB"),
...     col("EuclideanDistance")).show()
+---+---+-----------------+
|idA|idB|EuclideanDistance|
+---+---+-----------------+
|  3|  6| 2.23606797749979|
+---+---+-----------------+
...
>>> model.approxSimilarityJoin(df, df2, 3, distCol="EuclideanDistance").select(
...     col("datasetA.id").alias("idA"),
...     col("datasetB.id").alias("idB"),
...     col("EuclideanDistance")).show()
+---+---+-----------------+
|idA|idB|EuclideanDistance|
+---+---+-----------------+
|  3|  6| 2.23606797749979|
+---+---+-----------------+
...
>>> brpPath = temp_path + "/brp"
>>> brp.save(brpPath)
>>> brp2 = BucketedRandomProjectionLSH.load(brpPath)
>>> brp2.getBucketLength() == brp.getBucketLength()
True
>>> modelPath = temp_path + "/brp-model"
>>> model.save(modelPath)
>>> model2 = BucketedRandomProjectionLSHModel.load(modelPath)
>>> model.transform(df).head().hashes == model2.transform(df).head().hashes
True

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getBucketLength() Gets the value of bucketLength or its default value.
getInputCol() Gets the value of inputCol or its default value.
getNumHashTables() Gets the value of numHashTables or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.


getSeed() Gets the value of seed or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setBucketLength(value) Sets the value of bucketLength.
setInputCol(value) Sets the value of inputCol.
setNumHashTables(value) Sets the value of numHashTables.
setOutputCol(value) Sets the value of outputCol.
setParams(self, *[, inputCol, outputCol, ...]) Sets params for this BucketedRandomProjectionLSH.
setSeed(value) Sets the value of seed.
write() Returns an MLWriter instance for this ML instance.

Attributes

bucketLength
inputCol
numHashTables
outputCol
params Returns all params ordered by name.
seed

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.


explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getBucketLength()
Gets the value of bucketLength or its default value.

New in version 2.2.0.

getInputCol()
Gets the value of inputCol or its default value.

getNumHashTables()
Gets the value of numHashTables or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.


getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getSeed()
Gets the value of seed or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setBucketLength(value)
Sets the value of bucketLength.

New in version 2.2.0.

setInputCol(value)
Sets the value of inputCol.

setNumHashTables(value)
Sets the value of numHashTables.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, *, inputCol=None, outputCol=None, seed=None, numHashTables=1, bucketLength=None)
Sets params for this BucketedRandomProjectionLSH.

New in version 2.2.0.

setSeed(value)
Sets the value of seed.

write()
Returns an MLWriter instance for this ML instance.


Attributes Documentation

bucketLength = Param(parent='undefined', name='bucketLength', doc='the length of each hash bucket, a larger bucket lowers the false negative rate.')

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

numHashTables = Param(parent='undefined', name='numHashTables', doc='number of hash tables, where increasing number of hash tables lowers the false negative rate, and decreasing it improves the running performance.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

seed = Param(parent='undefined', name='seed', doc='random seed.')

BucketedRandomProjectionLSHModel

class pyspark.ml.feature.BucketedRandomProjectionLSHModel(java_model=None)
Model fitted by BucketedRandomProjectionLSH, where multiple random vectors are stored. The vectors are normalized to be unit vectors and each vector is used in a hash function: h_i(x) = floor(r_i · x / bucketLength), where r_i is the i-th random unit vector. The number of buckets will be (max L2 norm of input vectors) / bucketLength.

New in version 2.2.0.
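The hash formula above can be illustrated in plain NumPy; this is only a sketch with made-up values for the random unit vector and bucket length, not the model's actual internals.

import numpy as np

bucket_length = 1.0
r_i = np.array([0.6, 0.8])   # a made-up random unit vector, ||r_i|| = 1
x = np.array([-1.0, -1.0])   # an input point

# h_i(x) = floor(r_i . x / bucketLength)
h_i = np.floor(np.dot(r_i, x) / bucket_length)
print(h_i)   # -2.0 for these made-up values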

Methods

approxNearestNeighbors(dataset, key, ...[, ...]) Given a large dataset and an item, approximately find at most k items which have the closest distance to the item.
approxSimilarityJoin(datasetA, datasetB, ...) Join two datasets to approximately find all pairs of rows whose distance are smaller than the threshold.
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBucketLength() Gets the value of bucketLength or its default value.
getInputCol() Gets the value of inputCol or its default value.
getNumHashTables() Gets the value of numHashTables or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.


getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setOutputCol(value) Sets the value of outputCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

bucketLength
inputCol
numHashTables
outputCol
params Returns all params ordered by name.

Methods Documentation

approxNearestNeighbors(dataset, key, numNearestNeighbors, distCol='distCol')
Given a large dataset and an item, approximately find at most k items which have the closest distance to the item. If the outputCol is missing, the method will transform the data; if the outputCol exists, it will use that. This allows caching of the transformed data when necessary.

Parameters

dataset [pyspark.sql.DataFrame] The dataset to search for nearest neighbors of the key.

key [pyspark.ml.linalg.Vector] Feature vector representing the item to search for.

numNearestNeighbors [int] The maximum number of nearest neighbors.

distCol [str] Output column for storing the distance between each result row and the key.Use “distCol” as default value if it’s not specified.

Returns

pyspark.sql.DataFrame A dataset containing at most k items closest to the key. A column “distCol” is added to show the distance between each row and the key.


Notes

This method is experimental and will likely change behavior in the next release.
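A short hedged usage sketch, reusing the model, df2, and Vectors names from the class Examples above:

# Find the single nearest neighbour of a query point in df2 and keep only
# the id and the added distance column.
key = Vectors.dense([1.0, 2.0])
nn = model.approxNearestNeighbors(df2, key, numNearestNeighbors=1)
nn.select("id", "distCol").show()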

approxSimilarityJoin(datasetA, datasetB, threshold, distCol='distCol')
Join two datasets to approximately find all pairs of rows whose distance is smaller than the threshold. If the outputCol is missing, the method will transform the data; if the outputCol exists, it will use that. This allows caching of the transformed data when necessary.

Parameters

datasetA [pyspark.sql.DataFrame] One of the datasets to join.

datasetB [pyspark.sql.DataFrame] Another dataset to join.

threshold [float] The threshold for the distance of row pairs.

distCol [str, optional] Output column for storing the distance between each pair of rows. Use "distCol" as default value if it's not specified.

Returns

pyspark.sql.DataFrame A joined dataset containing pairs of rows. The original rows are in columns "datasetA" and "datasetB", and a column "distCol" is added to show the distance between each pair.
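As an illustration only (assuming an active spark session; the datasets are made up), an approximate join below a distance threshold of 1.5 might look like:

from pyspark.ml.feature import BucketedRandomProjectionLSH
from pyspark.ml.linalg import Vectors

dfA = spark.createDataFrame([(0, Vectors.dense([1.0, 1.0])),
                             (1, Vectors.dense([-1.0, -1.0]))], ["id", "features"])
dfB = spark.createDataFrame([(2, Vectors.dense([1.0, 0.0])),
                             (3, Vectors.dense([-1.0, 0.0]))], ["id", "features"])

model = BucketedRandomProjectionLSH(inputCol="features", outputCol="hashes",
                                    bucketLength=2.0, numHashTables=3).fit(dfA)

# Pairs of rows from dfA and dfB whose Euclidean distance is below 1.5;
# the original rows come back as struct columns "datasetA" and "datasetB".
pairs = model.approxSimilarityJoin(dfA, dfB, 1.5, distCol="EuclideanDistance")
pairs.select("datasetA.id", "datasetB.id", "EuclideanDistance").show()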

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
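The merge ordering described above (default param values < user-supplied values < extra) can be checked directly. A small sketch, using Bucketizer purely as an illustration since extractParamMap behaves the same way for every Params subclass:

from pyspark.ml.feature import Bucketizer

b = Bucketizer(outputCol="buckets")                 # user-supplied value
merged = b.extractParamMap(extra={b.outputCol: "other"})
print(merged[b.outputCol])                          # 'other': the extra value wins over the set value
print(b.extractParamMap()[b.handleInvalid])         # 'error': default value, nothing overrides it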

getBucketLength()
Gets the value of bucketLength or its default value.


New in version 2.2.0.

getInputCol()
Gets the value of inputCol or its default value.
getNumHashTables()
Gets the value of numHashTables or its default value.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getOutputCol()
Gets the value of outputCol or its default value.
getParam(paramName)
Gets a param by its name.
hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.
isDefined(param)
Checks whether a param is explicitly set by user or has a default value.
isSet(param)
Checks whether a param is explicitly set by user.
classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()
Returns an MLReader instance for this class.
save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)
Sets a parameter in the embedded param map.
setInputCol(value)
Sets the value of inputCol.
setOutputCol(value)
Sets the value of outputCol.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.


Attributes Documentation

bucketLength = Param(parent='undefined', name='bucketLength', doc='the length of each hash bucket, a larger bucket lowers the false negative rate.')

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

numHashTables = Param(parent='undefined', name='numHashTables', doc='number of hash tables, where increasing number of hash tables lowers the false negative rate, and decreasing it improves the running performance.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

Bucketizer

class pyspark.ml.feature.Bucketizer(*args, **kwargs)
Maps a column of continuous features to a column of feature buckets. Since 3.0.0, Bucketizer can map multiple columns at once by setting the inputCols parameter. Note that when both the inputCol and inputCols parameters are set, an Exception will be thrown. The splits parameter is only used for single column usage, and splitsArray is for multiple columns.

New in version 1.4.0.

Examples

>>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")),
...     (float("nan"), 1.0), (float("nan"), 0.0)]
>>> df = spark.createDataFrame(values, ["values1", "values2"])
>>> bucketizer = Bucketizer()
>>> bucketizer.setSplits([-float("inf"), 0.5, 1.4, float("inf")])
Bucketizer...
>>> bucketizer.setInputCol("values1")
Bucketizer...
>>> bucketizer.setOutputCol("buckets")
Bucketizer...
>>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect()
>>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1"))
>>> bucketed.show(truncate=False)
+-------+-------+
|values1|buckets|
+-------+-------+
|0.1    |0.0    |
|0.4    |0.0    |
|1.2    |1.0    |
|1.5    |2.0    |
|NaN    |3.0    |
|NaN    |3.0    |
+-------+-------+
...
>>> bucketizer.setParams(outputCol="b").transform(df).head().b
0.0
>>> bucketizerPath = temp_path + "/bucketizer"
>>> bucketizer.save(bucketizerPath)
>>> loadedBucketizer = Bucketizer.load(bucketizerPath)
>>> loadedBucketizer.getSplits() == bucketizer.getSplits()
True
>>> loadedBucketizer.transform(df).take(1) == bucketizer.transform(df).take(1)
True
>>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect()
>>> len(bucketed)
4
>>> bucketizer2 = Bucketizer(splitsArray=
...     [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]],
...     inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"])
>>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df)
>>> bucketed2.show(truncate=False)
+-------+-------+--------+--------+
|values1|values2|buckets1|buckets2|
+-------+-------+--------+--------+
|0.1    |0.0    |0.0     |0.0     |
|0.4    |1.0    |0.0     |1.0     |
|1.2    |1.3    |1.0     |1.0     |
|1.5    |NaN    |2.0     |2.0     |
|NaN    |1.0    |3.0     |1.0     |
|NaN    |0.0    |3.0     |0.0     |
+-------+-------+--------+--------+
...

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getHandleInvalid()  Gets the value of handleInvalid or its default value.
getInputCol()  Gets the value of inputCol or its default value.
getInputCols()  Gets the value of inputCols or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getOutputCols()  Gets the value of outputCols or its default value.
getParam(paramName)  Gets a param by its name.
getSplits()  Gets the value of splits or its default value.
getSplitsArray()  Gets the array of split points or its default value.


hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setHandleInvalid(value)  Sets the value of handleInvalid.
setInputCol(value)  Sets the value of inputCol.
setInputCols(value)  Sets the value of inputCols.
setOutputCol(value)  Sets the value of outputCol.
setOutputCols(value)  Sets the value of outputCols.
setParams(self, *[, splits, inputCol, ...])  Sets params for this Bucketizer.
setSplits(value)  Sets the value of splits.
setSplitsArray(value)  Sets the value of splitsArray.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.

Attributes

handleInvalid
inputCol
inputCols
outputCol
outputCols
params  Returns all params ordered by name.
splits
splitsArray

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance


explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getHandleInvalid()
Gets the value of handleInvalid or its default value.
getInputCol()
Gets the value of inputCol or its default value.
getInputCols()
Gets the value of inputCols or its default value.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getOutputCol()
Gets the value of outputCol or its default value.
getOutputCols()
Gets the value of outputCols or its default value.
getParam(paramName)
Gets a param by its name.
getSplits()
Gets the value of splits or its default value.
New in version 1.4.0.
getSplitsArray()
Gets the array of split points or its default value.
New in version 3.0.0.
hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.
isDefined(param)
Checks whether a param is explicitly set by user or has a default value.
isSet(param)
Checks whether a param is explicitly set by user.


classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()
Returns an MLReader instance for this class.
save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)
Sets a parameter in the embedded param map.
setHandleInvalid(value)
Sets the value of handleInvalid.
setInputCol(value)
Sets the value of inputCol.
setInputCols(value)
Sets the value of inputCols.
New in version 3.0.0.
setOutputCol(value)
Sets the value of outputCol.
setOutputCols(value)
Sets the value of outputCols.
New in version 3.0.0.
setParams(self, *, splits=None, inputCol=None, outputCol=None, handleInvalid="error", splitsArray=None, inputCols=None, outputCols=None)

Sets params for this Bucketizer.

New in version 1.4.0.

setSplits(value)
Sets the value of splits.
New in version 1.4.0.
setSplitsArray(value)
Sets the value of splitsArray.
New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.


Attributes Documentation

handleInvalid = Param(parent='undefined', name='handleInvalid', doc="how to handle invalid entries containing NaN values. Values outside the splits will always be treated as errors. Options are 'skip' (filter out rows with invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special additional bucket). Note that in the multiple column case, the invalid handling is applied to all columns. That said for 'error' it will throw an error if any invalids are found in any column, for 'skip' it will skip rows with any invalids in any columns, etc.")

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')

paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.

splits = Param(parent='undefined', name='splits', doc='Split points for mapping continuous features into buckets. With n+1 splits, there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which also includes y. The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be explicitly provided to cover all Double values; otherwise, values outside the splits specified will be treated as errors.')

splitsArray = Param(parent='undefined', name='splitsArray', doc='The array of split points for mapping continuous features into buckets for multiple columns. For each input column, with n+1 splits, there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which also includes y. The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be explicitly provided to cover all Double values; otherwise, values outside the splits specified will be treated as errors.')

ChiSqSelector

class pyspark.ml.feature.ChiSqSelector(*args, **kwargs)
Chi-Squared feature selection, which selects categorical features to use for predicting a categorical label. The selector supports different selection methods: numTopFeatures, percentile, fpr, fdr, fwe.

• numTopFeatures chooses a fixed number of top features according to a chi-squared test.

• percentile is similar but chooses a fraction of all features instead of a fixed number.

• fpr chooses all features whose p-values are below a threshold, thus controlling the false positive rate of selection.

• fdr uses the Benjamini-Hochberg procedure to choose all features whose false discovery rate is below a threshold.

• fwe chooses all features whose p-values are below a threshold. The threshold is scaled by 1/numFeatures, thus controlling the family-wise error rate of selection.

By default, the selection method is numTopFeatures, with the default number of top features set to 50.

New in version 2.0.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame(
...    [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
...     (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
...     (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)],
...    ["features", "label"])
>>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
>>> model = selector.fit(df)
>>> model.getFeaturesCol()
'features'
>>> model.setFeaturesCol("features")
ChiSqSelectorModel...
>>> model.transform(df).head().selectedFeatures
DenseVector([18.0])
>>> model.selectedFeatures
[2]
>>> chiSqSelectorPath = temp_path + "/chi-sq-selector"
>>> selector.save(chiSqSelectorPath)
>>> loadedSelector = ChiSqSelector.load(chiSqSelectorPath)
>>> loadedSelector.getNumTopFeatures() == selector.getNumTopFeatures()
True
>>> modelPath = temp_path + "/chi-sq-selector-model"
>>> model.save(modelPath)
>>> loadedModel = ChiSqSelectorModel.load(modelPath)
>>> loadedModel.selectedFeatures == model.selectedFeatures
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
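The example above uses the default numTopFeatures strategy. A minimal sketch, not from the original docs and reusing the df defined above, of switching to one of the other selection methods, here fpr:

from pyspark.ml.feature import ChiSqSelector

selector = ChiSqSelector(selectorType="fpr", fpr=0.05,
                         featuresCol="features", labelCol="label",
                         outputCol="selectedFeatures")
model = selector.fit(df)        # keeps every feature whose chi-squared p-value is below 0.05
print(model.selectedFeatures)   # indices of the retained features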

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getFdr()  Gets the value of fdr or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFpr()  Gets the value of fpr or its default value.
getFwe()  Gets the value of fwe or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getNumTopFeatures()  Gets the value of numTopFeatures or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getPercentile()  Gets the value of percentile or its default value.
getSelectorType()  Gets the value of selectorType or its default value.
hasDefault(param)  Checks whether a param has a default value.


hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFdr(value)  Sets the value of fdr.
setFeaturesCol(value)  Sets the value of featuresCol.
setFpr(value)  Sets the value of fpr.
setFwe(value)  Sets the value of fwe.
setLabelCol(value)  Sets the value of labelCol.
setNumTopFeatures(value)  Sets the value of numTopFeatures.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, *[, numTopFeatures, ...])  Sets params for this ChiSqSelector.
setPercentile(value)  Sets the value of percentile.
setSelectorType(value)  Sets the value of selectorType.
write()  Returns an MLWriter instance for this ML instance.

Attributes

fdr
featuresCol
fpr
fwe
labelCol
numTopFeatures
outputCol
params  Returns all params ordered by name.
percentile
selectorType

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns


JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
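A short sketch of how the returned iterator is typically consumed, assuming a DataFrame df with "features" and "label" columns as in the example further above; note that results may arrive out of index order:

from pyspark.ml.feature import ChiSqSelector

selector = ChiSqSelector(outputCol="selectedFeatures")
paramMaps = [{selector.numTopFeatures: 1}, {selector.numTopFeatures: 2}]

models = [None] * len(paramMaps)
for index, model in selector.fitMultiple(df, paramMaps):
    models[index] = model       # put each fitted model back in paramMaps order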

getFdr()
Gets the value of fdr or its default value.

New in version 2.2.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFpr()
Gets the value of fpr or its default value.


New in version 2.1.0.

getFwe()
Gets the value of fwe or its default value.
New in version 2.2.0.
getLabelCol()
Gets the value of labelCol or its default value.
getNumTopFeatures()
Gets the value of numTopFeatures or its default value.
New in version 2.0.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getOutputCol()
Gets the value of outputCol or its default value.
getParam(paramName)
Gets a param by its name.
getPercentile()
Gets the value of percentile or its default value.
New in version 2.1.0.
getSelectorType()
Gets the value of selectorType or its default value.
New in version 2.1.0.
hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.
isDefined(param)
Checks whether a param is explicitly set by user or has a default value.
isSet(param)
Checks whether a param is explicitly set by user.
classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()
Returns an MLReader instance for this class.
save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)
Sets a parameter in the embedded param map.
setFdr(value)
Sets the value of fdr. Only applicable when selectorType = "fdr".

New in version 2.2.0.


setFeaturesCol(value)
Sets the value of featuresCol.
setFpr(value)
Sets the value of fpr. Only applicable when selectorType = "fpr".
New in version 2.1.0.
setFwe(value)
Sets the value of fwe. Only applicable when selectorType = "fwe".
New in version 2.2.0.
setLabelCol(value)
Sets the value of labelCol.
setNumTopFeatures(value)
Sets the value of numTopFeatures. Only applicable when selectorType = "numTopFeatures".
New in version 2.0.0.
setOutputCol(value)
Sets the value of outputCol.
setParams(self, *, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, fdr=0.05, fwe=0.05)
Sets params for this ChiSqSelector.
New in version 2.0.0.
setPercentile(value)
Sets the value of percentile. Only applicable when selectorType = "percentile".
New in version 2.1.0.
setSelectorType(value)
Sets the value of selectorType.
New in version 2.1.0.
write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

fdr = Param(parent='undefined', name='fdr', doc='The upper bound of the expected false discovery rate.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fpr = Param(parent='undefined', name='fpr', doc='The highest p-value for features to be kept.')

fwe = Param(parent='undefined', name='fwe', doc='The upper bound of the expected family-wise error rate.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

numTopFeatures = Param(parent='undefined', name='numTopFeatures', doc='Number of features that selector will select, ordered by ascending p-value. If the number of features is < numTopFeatures, then this will select all features.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

percentile = Param(parent='undefined', name='percentile', doc='Percentile of features that selector will select, ordered by ascending p-value.')


selectorType = Param(parent='undefined', name='selectorType', doc='The selector type. Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.')

ChiSqSelectorModel

class pyspark.ml.feature.ChiSqSelectorModel(java_model=None)
Model fitted by ChiSqSelector.

New in version 2.0.0.

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFdr()  Gets the value of fdr or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFpr()  Gets the value of fpr or its default value.
getFwe()  Gets the value of fwe or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getNumTopFeatures()  Gets the value of numTopFeatures or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getPercentile()  Gets the value of percentile or its default value.
getSelectorType()  Gets the value of selectorType or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.


set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setOutputCol(value)  Sets the value of outputCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.

Attributes

fdr
featuresCol
fpr
fwe
labelCol
numTopFeatures
outputCol
params  Returns all params ordered by name.
percentile
selectedFeatures  List of indices to select (filter).
selectorType

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values


Returns

dict merged param map

getFdr()
Gets the value of fdr or its default value.
New in version 2.2.0.
getFeaturesCol()
Gets the value of featuresCol or its default value.
getFpr()
Gets the value of fpr or its default value.
New in version 2.1.0.
getFwe()
Gets the value of fwe or its default value.
New in version 2.2.0.
getLabelCol()
Gets the value of labelCol or its default value.
getNumTopFeatures()
Gets the value of numTopFeatures or its default value.
New in version 2.0.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getOutputCol()
Gets the value of outputCol or its default value.
getParam(paramName)
Gets a param by its name.
getPercentile()
Gets the value of percentile or its default value.
New in version 2.1.0.
getSelectorType()
Gets the value of selectorType or its default value.
New in version 2.1.0.
hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.
isDefined(param)
Checks whether a param is explicitly set by user or has a default value.
isSet(param)
Checks whether a param is explicitly set by user.
classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).


classmethod read()
Returns an MLReader instance for this class.
save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)
Sets a parameter in the embedded param map.
setFeaturesCol(value)
Sets the value of featuresCol.
New in version 3.0.0.
setOutputCol(value)
Sets the value of outputCol.
New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

fdr = Param(parent='undefined', name='fdr', doc='The upper bound of the expected false discovery rate.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fpr = Param(parent='undefined', name='fpr', doc='The highest p-value for features to be kept.')

fwe = Param(parent='undefined', name='fwe', doc='The upper bound of the expected family-wise error rate.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

numTopFeatures = Param(parent='undefined', name='numTopFeatures', doc='Number of features that selector will select, ordered by ascending p-value. If the number of features is < numTopFeatures, then this will select all features.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

percentile = Param(parent='undefined', name='percentile', doc='Percentile of features that selector will select, ordered by ascending p-value.')

selectedFeatures
List of indices to select (filter).

New in version 2.0.0.

selectorType = Param(parent='undefined', name='selectorType', doc='The selector type. Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.')


CountVectorizer

class pyspark.ml.feature.CountVectorizer(*args, **kwargs)
Extracts a vocabulary from document collections and generates a CountVectorizerModel.

New in version 1.6.0.

Examples

>>> df = spark.createDataFrame(
...    [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])],
...    ["label", "raw"])
>>> cv = CountVectorizer()
>>> cv.setInputCol("raw")
CountVectorizer...
>>> cv.setOutputCol("vectors")
CountVectorizer...
>>> model = cv.fit(df)
>>> model.setInputCol("raw")
CountVectorizerModel...
>>> model.transform(df).show(truncate=False)
+-----+---------------+-------------------------+
|label|raw            |vectors                  |
+-----+---------------+-------------------------+
|0    |[a, b, c]      |(3,[0,1,2],[1.0,1.0,1.0])|
|1    |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
+-----+---------------+-------------------------+
...
>>> sorted(model.vocabulary) == ['a', 'b', 'c']
True
>>> countVectorizerPath = temp_path + "/count-vectorizer"
>>> cv.save(countVectorizerPath)
>>> loadedCv = CountVectorizer.load(countVectorizerPath)
>>> loadedCv.getMinDF() == cv.getMinDF()
True
>>> loadedCv.getMinTF() == cv.getMinTF()
True
>>> loadedCv.getVocabSize() == cv.getVocabSize()
True
>>> modelPath = temp_path + "/count-vectorizer-model"
>>> model.save(modelPath)
>>> loadedModel = CountVectorizerModel.load(modelPath)
>>> loadedModel.vocabulary == model.vocabulary
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
>>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"],
...     inputCol="raw", outputCol="vectors")
>>> fromVocabModel.transform(df).show(truncate=False)
+-----+---------------+-------------------------+
|label|raw            |vectors                  |
+-----+---------------+-------------------------+
|0    |[a, b, c]      |(3,[0,1,2],[1.0,1.0,1.0])|
|1    |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
+-----+---------------+-------------------------+
...


Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getBinary()  Gets the value of binary or its default value.
getInputCol()  Gets the value of inputCol or its default value.
getMaxDF()  Gets the value of maxDF or its default value.
getMinDF()  Gets the value of minDF or its default value.
getMinTF()  Gets the value of minTF or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getVocabSize()  Gets the value of vocabSize or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setBinary(value)  Sets the value of binary.
setInputCol(value)  Sets the value of inputCol.
setMaxDF(value)  Sets the value of maxDF.
setMinDF(value)  Sets the value of minDF.
setMinTF(value)  Sets the value of minTF.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, *[, minTF, minDF, maxDF, ...])  Set the params for the CountVectorizer.
setVocabSize(value)  Sets the value of vocabSize.


write()  Returns an MLWriter instance for this ML instance.

Attributes

binary
inputCol
maxDF
minDF
minTF
outputCol
params  Returns all params ordered by name.
vocabSize

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters


dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getBinary()
Gets the value of binary or its default value.
New in version 2.0.0.
getInputCol()
Gets the value of inputCol or its default value.
getMaxDF()
Gets the value of maxDF or its default value.
New in version 2.4.0.
getMinDF()
Gets the value of minDF or its default value.
New in version 1.6.0.
getMinTF()
Gets the value of minTF or its default value.
New in version 1.6.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getOutputCol()
Gets the value of outputCol or its default value.
getParam(paramName)
Gets a param by its name.
getVocabSize()
Gets the value of vocabSize or its default value.
New in version 1.6.0.


hasDefault(param)
Checks whether a param has a default value.
hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.
isDefined(param)
Checks whether a param is explicitly set by user or has a default value.
isSet(param)
Checks whether a param is explicitly set by user.
classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()
Returns an MLReader instance for this class.
save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)
Sets a parameter in the embedded param map.
setBinary(value)
Sets the value of binary.
New in version 2.0.0.
setInputCol(value)
Sets the value of inputCol.
setMaxDF(value)
Sets the value of maxDF.
New in version 2.4.0.
setMinDF(value)
Sets the value of minDF.
New in version 1.6.0.
setMinTF(value)
Sets the value of minTF.
New in version 1.6.0.
setOutputCol(value)
Sets the value of outputCol.
setParams(self, *, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, inputCol=None, outputCol=None)
Set the params for the CountVectorizer.
New in version 1.6.0.
setVocabSize(value)
Sets the value of vocabSize.
New in version 1.6.0.
write()
Returns an MLWriter instance for this ML instance.


Attributes Documentation

binary = Param(parent='undefined', name='binary', doc='Binary toggle to control the output vector values. If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for discrete probabilistic models that model binary events rather than integer counts. Default False')

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

maxDF = Param(parent='undefined', name='maxDF', doc='Specifies the maximum number of different documents a term could appear in to be included in the vocabulary. A term that appears more than the threshold will be ignored. If this is an integer >= 1, this specifies the maximum number of documents the term could appear in; if this is a double in [0,1), then this specifies the maximum fraction of documents the term could appear in. Default (2^63) - 1')

minDF = Param(parent='undefined', name='minDF', doc='Specifies the minimum number of different documents a term must appear in to be included in the vocabulary. If this is an integer >= 1, this specifies the number of documents the term must appear in; if this is a double in [0,1), then this specifies the fraction of documents. Default 1.0')

minTF = Param(parent='undefined', name='minTF', doc="Filter to ignore rare words in a document. For each document, terms with frequency/count less than the given threshold are ignored. If this is an integer >= 1, then this specifies a count (of times the term must appear in the document); if this is a double in [0,1), then this specifies a fraction (out of the document's token count). Note that the parameter is only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0")

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

vocabSize = Param(parent='undefined', name='vocabSize', doc='max size of the vocabulary. Default 1 << 18.')
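Since minDF and maxDF accept either an absolute document count (an integer >= 1) or a fraction of documents (a double in [0, 1)), here is a small sketch of both spellings, not from the original docs, assuming a DataFrame df with an array-of-strings column "raw":

from pyspark.ml.feature import CountVectorizer

# Absolute counts: keep only terms that appear in at least 2 documents.
cv_counts = CountVectorizer(inputCol="raw", outputCol="vectors", minDF=2)

# Fractions: keep terms in at least 20% and at most 90% of documents.
cv_fracs = CountVectorizer(inputCol="raw", outputCol="vectors",
                           minDF=0.2, maxDF=0.9)

model = cv_fracs.fit(df)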

CountVectorizerModel

class pyspark.ml.feature.CountVectorizerModel(java_model=None)
Model fitted by CountVectorizer.

New in version 1.6.0.

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
from_vocabulary(vocabulary, inputCol[, ...])  Construct the model directly from a vocabulary list of strings, requires an active SparkContext.
getBinary()  Gets the value of binary or its default value.
getInputCol()  Gets the value of inputCol or its default value.
getMaxDF()  Gets the value of maxDF or its default value.
getMinDF()  Gets the value of minDF or its default value.
getMinTF()  Gets the value of minTF or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.


getParam(paramName)  Gets a param by its name.
getVocabSize()  Gets the value of vocabSize or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setBinary(value)  Sets the value of binary.
setInputCol(value)  Sets the value of inputCol.
setMinTF(value)  Sets the value of minTF.
setOutputCol(value)  Sets the value of outputCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.

Attributes

binary
inputCol
maxDF
minDF
minTF
outputCol
params  Returns all params ordered by name.
vocabSize
vocabulary  An array of terms in the vocabulary.

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance


explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

classmethod from_vocabulary(vocabulary, inputCol, outputCol=None, minTF=None, binary=None)

Construct the model directly from a vocabulary list of strings, requires an active SparkContext.

New in version 2.4.0.

getBinary()
Gets the value of binary or its default value.
New in version 2.0.0.
getInputCol()
Gets the value of inputCol or its default value.
getMaxDF()
Gets the value of maxDF or its default value.
New in version 2.4.0.
getMinDF()
Gets the value of minDF or its default value.
New in version 1.6.0.
getMinTF()
Gets the value of minTF or its default value.
New in version 1.6.0.
getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
getOutputCol()
Gets the value of outputCol or its default value.
getParam(paramName)
Gets a param by its name.
getVocabSize()
Gets the value of vocabSize or its default value.
New in version 1.6.0.
hasDefault(param)
Checks whether a param has a default value.


hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.
isDefined(param)
Checks whether a param is explicitly set by user or has a default value.
isSet(param)
Checks whether a param is explicitly set by user.
classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).
classmethod read()
Returns an MLReader instance for this class.
save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)
Sets a parameter in the embedded param map.
setBinary(value)
Sets the value of binary.
New in version 2.4.0.
setInputCol(value)
Sets the value of inputCol.
New in version 3.0.0.
setMinTF(value)
Sets the value of minTF.
New in version 2.4.0.
setOutputCol(value)
Sets the value of outputCol.
New in version 3.0.0.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.


Attributes Documentation

binary = Param(parent='undefined', name='binary', doc='Binary toggle to control the output vector values. If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for discrete probabilistic models that model binary events rather than integer counts. Default False')

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

maxDF = Param(parent='undefined', name='maxDF', doc='Specifies the maximum number of different documents a term could appear in to be included in the vocabulary. A term that appears more than the threshold will be ignored. If this is an integer >= 1, this specifies the maximum number of documents the term could appear in; if this is a double in [0,1), then this specifies the maximum fraction of documents the term could appear in. Default (2^63) - 1')

minDF = Param(parent='undefined', name='minDF', doc='Specifies the minimum number of different documents a term must appear in to be included in the vocabulary. If this is an integer >= 1, this specifies the number of documents the term must appear in; if this is a double in [0,1), then this specifies the fraction of documents. Default 1.0')

minTF = Param(parent='undefined', name='minTF', doc="Filter to ignore rare words in a document. For each document, terms with frequency/count less than the given threshold are ignored. If this is an integer >= 1, then this specifies a count (of times the term must appear in the document); if this is a double in [0,1), then this specifies a fraction (out of the document's token count). Note that the parameter is only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0")

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

vocabSize = Param(parent='undefined', name='vocabSize', doc='max size of the vocabulary. Default 1 << 18.')

vocabulary
An array of terms in the vocabulary.

New in version 1.6.0.

DCT

class pyspark.ml.feature.DCT(*args, **kwargs)
A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero padding is performed on the input vector. It returns a real vector of the same length representing the DCT. The return vector is scaled such that the transform matrix is unitary (aka scaled DCT-II).

New in version 1.6.0.

Notes

More information on Wikipedia.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df1 = spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"])
>>> dct = DCT( )
>>> dct.setInverse(False)
DCT...
>>> dct.setInputCol("vec")
DCT...
>>> dct.setOutputCol("resultVec")
DCT...
>>> df2 = dct.transform(df1)
>>> df2.head().resultVec
DenseVector([10.969..., -0.707..., -2.041...])
>>> df3 = DCT(inverse=True, inputCol="resultVec", outputCol="origVec").transform(df2)
>>> df3.head().origVec
DenseVector([5.0, 8.0, 6.0])
>>> dctPath = temp_path + "/dct"
>>> dct.save(dctPath)
>>> loadedDtc = DCT.load(dctPath)
>>> loadedDtc.transform(df1).take(1) == dct.transform(df1).take(1)
True
>>> loadedDtc.getInverse()
False

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol()  Gets the value of inputCol or its default value.
getInverse()  Gets the value of inverse or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setInputCol(value)  Sets the value of inputCol.
setInverse(value)  Sets the value of inverse.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, *[, inverse, inputCol, ...])  Sets params for this DCT.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.


Attributes

inputCol
inverse
outputCol
params  Returns all params ordered by name.

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
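A minimal sketch of the documented merge ordering (default param values < user-supplied values < extra); this is an illustration, not code from the official docs.

from pyspark.ml.feature import DCT

dct = DCT(inputCol="vec", outputCol="resultVec")   # inverse keeps its default (False)
merged = dct.extractParamMap({dct.inverse: True})  # the `extra` map wins over the default
print(merged[dct.inverse])    # True  (from `extra`)
print(merged[dct.inputCol])   # 'vec' (user-supplied value)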

getInputCol()Gets the value of inputCol or its default value.

getInverse()Gets the value of inverse or its default value.

New in version 1.6.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getOutputCol()Gets the value of outputCol or its default value.


getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setInputCol(value)Sets the value of inputCol.

setInverse(value)Sets the value of inverse.

New in version 1.6.0.

setOutputCol(value)Sets the value of outputCol.

setParams(self, \*, inverse=False, inputCol=None, outputCol=None)Sets params for this DCT.

New in version 1.6.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.


Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

inverse = Param(parent='undefined', name='inverse', doc='Set transformer to perform inverse DCT, default False.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

ElementwiseProduct

class pyspark.ml.feature.ElementwiseProduct(*args, **kwargs)
Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a provided “weight” vector. In other words, it scales each column of the dataset by a scalar multiplier.

New in version 1.5.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], ["values"])
>>> ep = ElementwiseProduct()
>>> ep.setScalingVec(Vectors.dense([1.0, 2.0, 3.0]))
ElementwiseProduct...
>>> ep.setInputCol("values")
ElementwiseProduct...
>>> ep.setOutputCol("eprod")
ElementwiseProduct...
>>> ep.transform(df).head().eprod
DenseVector([2.0, 2.0, 9.0])
>>> ep.setParams(scalingVec=Vectors.dense([2.0, 3.0, 5.0])).transform(df).head().eprod
DenseVector([4.0, 3.0, 15.0])
>>> elementwiseProductPath = temp_path + "/elementwise-product"
>>> ep.save(elementwiseProductPath)
>>> loadedEp = ElementwiseProduct.load(elementwiseProductPath)
>>> loadedEp.getScalingVec() == ep.getScalingVec()
True
>>> loadedEp.transform(df).take(1) == ep.transform(df).take(1)
True
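A minimal sketch, not from the official example above, of a common use of ElementwiseProduct: applying per-feature weights to an assembled feature vector. It assumes an active SparkSession named spark.

from pyspark.ml import Pipeline
from pyspark.ml.feature import ElementwiseProduct, VectorAssembler
from pyspark.ml.linalg import Vectors

df = spark.createDataFrame([(1.0, 10.0, 100.0), (2.0, 20.0, 200.0)], ["x", "y", "z"])

# Assemble the raw columns, then down-weight y and z with a fixed scaling vector.
assembler = VectorAssembler(inputCols=["x", "y", "z"], outputCol="raw")
weighter = ElementwiseProduct(scalingVec=Vectors.dense([1.0, 0.5, 0.01]),
                              inputCol="raw", outputCol="weighted")

Pipeline(stages=[assembler, weighter]).fit(df).transform(df).show(truncate=False)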

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol()  Gets the value of inputCol or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getScalingVec()  Gets the value of scalingVec or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)  Sets a parameter in the embedded param map.
setInputCol(value)  Sets the value of inputCol.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, \*[, scalingVec, inputCol, ...])  Sets params for this ElementwiseProduct.
setScalingVec(value)  Sets the value of scalingVec.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.

Attributes

inputCol
outputCol
params  Returns all params ordered by name.
scalingVec


Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCol()Gets the value of inputCol or its default value.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

getScalingVec()Gets the value of scalingVec or its default value.

New in version 2.0.0.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.


isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setInputCol(value)Sets the value of inputCol.

setOutputCol(value)Sets the value of outputCol.

setParams(self, \*, scalingVec=None, inputCol=None, outputCol=None)Sets params for this ElementwiseProduct.

New in version 1.5.0.

setScalingVec(value)Sets the value of scalingVec.

New in version 2.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

scalingVec = Param(parent='undefined', name='scalingVec', doc='Vector for hadamard product.')


FeatureHasher

class pyspark.ml.feature.FeatureHasher(*args, **kwargs)
Feature hashing projects a set of categorical or numerical features into a feature vector of specified dimension (typically substantially smaller than that of the original feature space). This is done using the hashing trick (https://en.wikipedia.org/wiki/Feature_hashing) to map features to indices in the feature vector.

The FeatureHasher transformer operates on multiple columns. Each column may contain either numeric or categorical features. Behavior and handling of column data types is as follows:

• Numeric columns: For numeric features, the hash value of the column name is used to map the feature value to its index in the feature vector. By default, numeric features are not treated as categorical (even when they are integers). To treat them as categorical, specify the relevant columns in categoricalCols.

• String columns: For categorical features, the hash value of the string “column_name=value” is used to map to the vector index, with an indicator value of 1.0. Thus, categorical features are “one-hot” encoded (similarly to using OneHotEncoder with dropLast=false).

• Boolean columns: Boolean values are treated in the same way as string columns. That is, boolean features are represented as “column_name=true” or “column_name=false”, with an indicator value of 1.0.

Null (missing) values are ignored (implicitly zero in the resulting feature vector).

Since a simple modulo is used to transform the hash function to a vector index, it is advisable to use a power of two as the numFeatures parameter; otherwise the features will not be mapped evenly to the vector indices.

New in version 2.3.0.

Examples

>>> data = [(2.0, True, "1", "foo"), (3.0, False, "2", "bar")]
>>> cols = ["real", "bool", "stringNum", "string"]
>>> df = spark.createDataFrame(data, cols)
>>> hasher = FeatureHasher()
>>> hasher.setInputCols(cols)
FeatureHasher...
>>> hasher.setOutputCol("features")
FeatureHasher...
>>> hasher.transform(df).head().features
SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
>>> hasher.setCategoricalCols(["real"]).transform(df).head().features
SparseVector(262144, {171257: 1.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
>>> hasherPath = temp_path + "/hasher"
>>> hasher.save(hasherPath)
>>> loadedHasher = FeatureHasher.load(hasherPath)
>>> loadedHasher.getNumFeatures() == hasher.getNumFeatures()
True
>>> loadedHasher.transform(df).head().features == hasher.transform(df).head().features
True
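A minimal sketch, not part of the official example above, of feeding hashed features directly into a classifier; it assumes an active SparkSession named spark.

from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import FeatureHasher

train = spark.createDataFrame(
    [(2.0, True, "1", "foo", 1.0), (3.0, False, "2", "bar", 0.0)],
    ["real", "bool", "stringNum", "string", "label"])

# Hash mixed numeric/boolean/string columns into one fixed-width vector.
hasher = FeatureHasher(inputCols=["real", "bool", "stringNum", "string"],
                       outputCol="features",
                       numFeatures=1 << 10)  # a power of two, as recommended above
lr = LogisticRegression(maxIter=5)

model = Pipeline(stages=[hasher, lr]).fit(train)
model.transform(train).select("label", "prediction").show()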


Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCategoricalCols()  Gets the value of categoricalCols or its default value.
getInputCols()  Gets the value of inputCols or its default value.
getNumFeatures()  Gets the value of numFeatures or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)  Sets a parameter in the embedded param map.
setCategoricalCols(value)  Sets the value of categoricalCols.
setInputCols(value)  Sets the value of inputCols.
setNumFeatures(value)  Sets the value of numFeatures.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, \*[, numFeatures, ...])  Sets params for this FeatureHasher.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.


Attributes

categoricalCols
inputCols
numFeatures
outputCol
params  Returns all params ordered by name.

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getCategoricalCols()
Gets the value of categoricalCols or its default value.

New in version 2.3.0.

getInputCols()Gets the value of inputCols or its default value.

getNumFeatures()Gets the value of numFeatures or its default value.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.


getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setCategoricalCols(value)Sets the value of categoricalCols.

New in version 2.3.0.

setInputCols(value)Sets the value of inputCols.

setNumFeatures(value)Sets the value of numFeatures.

setOutputCol(value)Sets the value of outputCol.

setParams(self, \*, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None)

Sets params for this FeatureHasher.

New in version 2.3.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset


write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

categoricalCols = Param(parent='undefined', name='categoricalCols', doc='numeric columns to treat as categorical')

inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')

numFeatures = Param(parent='undefined', name='numFeatures', doc='Number of features. Should be greater than 0.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

FValueSelector

class pyspark.ml.feature.FValueSelector(*args, **kwargs)
F Value Regression feature selector, which selects continuous features to use for predicting a continuous label. The selector supports different selection methods: numTopFeatures, percentile, fpr, fdr, fwe.

• numTopFeatures chooses a fixed number of top features according to an F value regression test.

• percentile is similar but chooses a fraction of all features instead of a fixed number.

• fpr chooses all features whose p-values are below a threshold, thus controlling the false positive rate of selection.

• fdr uses the Benjamini-Hochberg procedure to choose all features whose false discovery rate is below a threshold.

• fwe chooses all features whose p-values are below a threshold. The threshold is scaled by 1/numFeatures, thus controlling the family-wise error rate of selection.

By default, the selection method is numTopFeatures, with the default number of top features set to 50.

New in version 3.1.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame(
...     [(Vectors.dense([6.0, 7.0, 0.0, 7.0, 6.0, 0.0]), 4.6),
...      (Vectors.dense([0.0, 9.0, 6.0, 0.0, 5.0, 9.0]), 6.6),
...      (Vectors.dense([0.0, 9.0, 3.0, 0.0, 5.0, 5.0]), 5.1),
...      (Vectors.dense([0.0, 9.0, 8.0, 5.0, 6.0, 4.0]), 7.6),
...      (Vectors.dense([8.0, 9.0, 6.0, 5.0, 4.0, 4.0]), 9.0),
...      (Vectors.dense([8.0, 9.0, 6.0, 4.0, 0.0, 0.0]), 9.0)],
...     ["features", "label"])
>>> selector = FValueSelector(numTopFeatures=1, outputCol="selectedFeatures")
>>> model = selector.fit(df)
>>> model.getFeaturesCol()
'features'
>>> model.setFeaturesCol("features")
FValueSelectorModel...
>>> model.transform(df).head().selectedFeatures
DenseVector([0.0])
>>> model.selectedFeatures
[2]
>>> fvalueSelectorPath = temp_path + "/fvalue-selector"
>>> selector.save(fvalueSelectorPath)
>>> loadedSelector = FValueSelector.load(fvalueSelectorPath)
>>> loadedSelector.getNumTopFeatures() == selector.getNumTopFeatures()
True
>>> modelPath = temp_path + "/fvalue-selector-model"
>>> model.save(modelPath)
>>> loadedModel = FValueSelectorModel.load(modelPath)
>>> loadedModel.selectedFeatures == model.selectedFeatures
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
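A minimal sketch, not part of the official example above, of switching between the documented selection methods; it assumes the SparkSession spark and the DataFrame df from the example.

from pyspark.ml.feature import FValueSelector

# Keep a fixed number of features (the default selection method).
top2 = FValueSelector(numTopFeatures=2, outputCol="top2").fit(df)

# Instead, keep every feature whose p-value is below 0.05.
by_fpr = FValueSelector(selectorType="fpr", fpr=0.05, outputCol="byFpr").fit(df)

print(top2.selectedFeatures, by_fpr.selectedFeatures)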

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getFdr()  Gets the value of fdr or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFpr()  Gets the value of fpr or its default value.
getFwe()  Gets the value of fwe or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getNumTopFeatures()  Gets the value of numTopFeatures or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getPercentile()  Gets the value of percentile or its default value.
getSelectorType()  Gets the value of selectorType or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)  Sets a parameter in the embedded param map.
setFdr(value)  Sets the value of fdr.
setFeaturesCol(value)  Sets the value of featuresCol.
setFpr(value)  Sets the value of fpr.
setFwe(value)  Sets the value of fwe.
setLabelCol(value)  Sets the value of labelCol.
setNumTopFeatures(value)  Sets the value of numTopFeatures.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, \*[, numTopFeatures, ...])  Sets params for this FValueSelector.
setPercentile(value)  Sets the value of percentile.
setSelectorType(value)  Sets the value of selectorType.
write()  Returns an MLWriter instance for this ML instance.

Attributes

fdr
featuresCol
fpr
fwe
labelCol
numTopFeatures
outputCol
params  Returns all params ordered by name.
percentile
selectorType

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance


Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)
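A minimal sketch, not from the official docs, of the list-of-param-maps form described above; it assumes the selector and df defined in the Examples section.

param_maps = [
    {selector.numTopFeatures: 1},
    {selector.numTopFeatures: 3},
]
models = selector.fit(df, param_maps)   # returns one fitted model per param map
for m in models:
    print(m.selectedFeatures)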

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getFdr()Gets the value of fdr or its default value.

New in version 2.2.0.

getFeaturesCol()Gets the value of featuresCol or its default value.


getFpr()Gets the value of fpr or its default value.

New in version 2.1.0.

getFwe()Gets the value of fwe or its default value.

New in version 2.2.0.

getLabelCol()Gets the value of labelCol or its default value.

getNumTopFeatures()Gets the value of numTopFeatures or its default value.

New in version 2.0.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

getPercentile()Gets the value of percentile or its default value.

New in version 2.1.0.

getSelectorType()Gets the value of selectorType or its default value.

New in version 2.1.0.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFdr(value)Sets the value of fdr. Only applicable when selectorType = “fdr”.


New in version 2.2.0.

setFeaturesCol(value)Sets the value of featuresCol.

setFpr(value)Sets the value of fpr. Only applicable when selectorType = “fpr”.

New in version 2.1.0.

setFwe(value)Sets the value of fwe. Only applicable when selectorType = “fwe”.

New in version 2.2.0.

setLabelCol(value)Sets the value of labelCol.

setNumTopFeatures(value)Sets the value of numTopFeatures. Only applicable when selectorType = “numTopFeatures”.

New in version 2.0.0.

setOutputCol(value)Sets the value of outputCol.

setParams(self, \*, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, fdr=0.05, fwe=0.05)

Sets params for this FValueSelector.

New in version 3.1.0.

setPercentile(value)Sets the value of percentile. Only applicable when selectorType = “percentile”.

New in version 2.1.0.

setSelectorType(value)Sets the value of selectorType.

New in version 2.1.0.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

fdr = Param(parent='undefined', name='fdr', doc='The upper bound of the expected false discovery rate.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fpr = Param(parent='undefined', name='fpr', doc='The highest p-value for features to be kept.')

fwe = Param(parent='undefined', name='fwe', doc='The upper bound of the expected family-wise error rate.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

numTopFeatures = Param(parent='undefined', name='numTopFeatures', doc='Number of features that selector will select, ordered by ascending p-value. If the number of features is < numTopFeatures, then this will select all features.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.


percentile = Param(parent='undefined', name='percentile', doc='Percentile of features that selector will select, ordered by ascending p-value.')

selectorType = Param(parent='undefined', name='selectorType', doc='The selector type. Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.')

FValueSelectorModel

class pyspark.ml.feature.FValueSelectorModel(java_model=None)
Model fitted by FValueSelector.

New in version 3.1.0.

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFdr()  Gets the value of fdr or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFpr()  Gets the value of fpr or its default value.
getFwe()  Gets the value of fwe or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getNumTopFeatures()  Gets the value of numTopFeatures or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
getPercentile()  Gets the value of percentile or its default value.
getSelectorType()  Gets the value of selectorType or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setOutputCol(value)  Sets the value of outputCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.

Attributes

fdr
featuresCol
fpr
fwe
labelCol
numTopFeatures
outputCol
params  Returns all params ordered by name.
percentile
selectedFeatures  List of indices to select (filter).
selectorType

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters


extra [dict, optional] extra param values

Returns

dict merged param map

getFdr()Gets the value of fdr or its default value.

New in version 2.2.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getFpr()Gets the value of fpr or its default value.

New in version 2.1.0.

getFwe()Gets the value of fwe or its default value.

New in version 2.2.0.

getLabelCol()Gets the value of labelCol or its default value.

getNumTopFeatures()Gets the value of numTopFeatures or its default value.

New in version 2.0.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

getPercentile()Gets the value of percentile or its default value.

New in version 2.1.0.

getSelectorType()Gets the value of selectorType or its default value.

New in version 2.1.0.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).


classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setOutputCol(value)Sets the value of outputCol.

New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

fdr = Param(parent='undefined', name='fdr', doc='The upper bound of the expected false discovery rate.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fpr = Param(parent='undefined', name='fpr', doc='The highest p-value for features to be kept.')

fwe = Param(parent='undefined', name='fwe', doc='The upper bound of the expected family-wise error rate.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

numTopFeatures = Param(parent='undefined', name='numTopFeatures', doc='Number of features that selector will select, ordered by ascending p-value. If the number of features is < numTopFeatures, then this will select all features.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

percentile = Param(parent='undefined', name='percentile', doc='Percentile of features that selector will select, ordered by ascending p-value.')

selectedFeatures
List of indices to select (filter).

New in version 2.0.0.

selectorType = Param(parent='undefined', name='selectorType', doc='The selector type. Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.')


HashingTF

class pyspark.ml.feature.HashingTF(*args, **kwargs)
Maps a sequence of terms to their term frequencies using the hashing trick. Currently we use Austin Appleby’s MurmurHash 3 algorithm (MurmurHash3_x86_32) to calculate the hash code value for the term object. Since a simple modulo is used to transform the hash function to a column index, it is advisable to use a power of two as the numFeatures parameter; otherwise the features will not be mapped evenly to the columns.

New in version 1.3.0.

Examples

>>> df = spark.createDataFrame([(["a", "b", "c"],)], ["words"])
>>> hashingTF = HashingTF(inputCol="words", outputCol="features")
>>> hashingTF.setNumFeatures(10)
HashingTF...
>>> hashingTF.transform(df).head().features
SparseVector(10, {5: 1.0, 7: 1.0, 8: 1.0})
>>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs
SparseVector(10, {5: 1.0, 7: 1.0, 8: 1.0})
>>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
>>> hashingTF.transform(df, params).head().vector
SparseVector(5, {0: 1.0, 2: 1.0, 3: 1.0})
>>> hashingTFPath = temp_path + "/hashing-tf"
>>> hashingTF.save(hashingTFPath)
>>> loadedHashingTF = HashingTF.load(hashingTFPath)
>>> loadedHashingTF.getNumFeatures() == hashingTF.getNumFeatures()
True
>>> loadedHashingTF.transform(df).take(1) == hashingTF.transform(df).take(1)
True
>>> hashingTF.indexOf("b")
5
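A minimal sketch, not part of the official example above, of the common TF-IDF pattern that feeds HashingTF output into IDF inside a Pipeline; it assumes an active SparkSession named spark.

from pyspark.ml import Pipeline
from pyspark.ml.feature import Tokenizer, HashingTF, IDF

docs = spark.createDataFrame([(0, "spark spark hashing trick"),
                              (1, "term frequency counts")], ["id", "text"])

tokenizer = Tokenizer(inputCol="text", outputCol="words")
tf = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=1 << 10)
idf = IDF(inputCol="rawFeatures", outputCol="features")

# Fit computes the IDF weights; transform produces the TF-IDF vectors.
model = Pipeline(stages=[tokenizer, tf, idf]).fit(docs)
model.transform(docs).select("id", "features").show(truncate=False)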

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBinary()  Gets the value of binary or its default value.
getInputCol()  Gets the value of inputCol or its default value.
getNumFeatures()  Gets the value of numFeatures or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
indexOf(term)  Returns the index of the input term.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)  Sets a parameter in the embedded param map.
setBinary(value)  Sets the value of binary.
setInputCol(value)  Sets the value of inputCol.
setNumFeatures(value)  Sets the value of numFeatures.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, \*[, numFeatures, binary, ...])  Sets params for this HashingTF.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.

Attributes

binary
inputCol
numFeatures
outputCol
params  Returns all params ordered by name.

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance


explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getBinary()Gets the value of binary or its default value.

New in version 2.0.0.

getInputCol()Gets the value of inputCol or its default value.

getNumFeatures()Gets the value of numFeatures or its default value.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

indexOf(term)Returns the index of the input term.

New in version 3.0.0.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.


save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setBinary(value)Sets the value of binary .

New in version 2.0.0.

setInputCol(value)Sets the value of inputCol.

setNumFeatures(value)Sets the value of numFeatures.

setOutputCol(value)Sets the value of outputCol.

setParams(self, \*, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None)Sets params for this HashingTF.

New in version 1.3.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

binary = Param(parent='undefined', name='binary', doc='If True, all non zero counts are set to 1. This is useful for discrete probabilistic models that model binary events rather than integer counts. Default False.')

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

numFeatures = Param(parent='undefined', name='numFeatures', doc='Number of features. Should be greater than 0.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.


IDF

class pyspark.ml.feature.IDF(*args, **kwargs)
Compute the Inverse Document Frequency (IDF) given a collection of documents.

New in version 1.4.0.

Examples

>>> from pyspark.ml.linalg import DenseVector
>>> df = spark.createDataFrame([(DenseVector([1.0, 2.0]),),
...     (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)], ["tf"])
>>> idf = IDF(minDocFreq=3)
>>> idf.setInputCol("tf")
IDF...
>>> idf.setOutputCol("idf")
IDF...
>>> model = idf.fit(df)
>>> model.setOutputCol("idf")
IDFModel...
>>> model.getMinDocFreq()
3
>>> model.idf
DenseVector([0.0, 0.0])
>>> model.docFreq
[0, 3]
>>> model.numDocs == df.count()
True
>>> model.transform(df).head().idf
DenseVector([0.0, 0.0])
>>> idf.setParams(outputCol="freqs").fit(df).transform(df).collect()[1].freqs
DenseVector([0.0, 0.0])
>>> params = {idf.minDocFreq: 1, idf.outputCol: "vector"}
>>> idf.fit(df, params).transform(df).head().vector
DenseVector([0.2877, 0.0])
>>> idfPath = temp_path + "/idf"
>>> idf.save(idfPath)
>>> loadedIdf = IDF.load(idfPath)
>>> loadedIdf.getMinDocFreq() == idf.getMinDocFreq()
True
>>> modelPath = temp_path + "/idf-model"
>>> model.save(modelPath)
>>> loadedModel = IDFModel.load(modelPath)
>>> loadedModel.transform(df).head().idf == model.transform(df).head().idf
True


Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getInputCol()  Gets the value of inputCol or its default value.
getMinDocFreq()  Gets the value of minDocFreq or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)  Sets a parameter in the embedded param map.
setInputCol(value)  Sets the value of inputCol.
setMinDocFreq(value)  Sets the value of minDocFreq.
setOutputCol(value)  Sets the value of outputCol.
setParams(self, \*[, minDocFreq, inputCol, ...])  Sets params for this IDF.
write()  Returns an MLWriter instance for this ML instance.


Attributes

inputCol
minDocFreq
outputCol
params  Returns all params ordered by name.

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)


fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
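A minimal sketch, not from the official docs, of consuming the iterator returned by fitMultiple; it assumes the idf estimator and DataFrame df from the Examples section. Indices may arrive out of order, so they are used to place each model.

param_maps = [{idf.minDocFreq: 0}, {idf.minDocFreq: 2}]
models = [None] * len(param_maps)
for index, model in idf.fitMultiple(df, param_maps):
    models[index] = model          # model was fit using param_maps[index]
print([m.idf for m in models])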

getInputCol()Gets the value of inputCol or its default value.

getMinDocFreq()Gets the value of minDocFreq or its default value.

New in version 1.4.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setInputCol(value)Sets the value of inputCol.

setMinDocFreq(value)Sets the value of minDocFreq .


New in version 1.4.0.

setOutputCol(value)Sets the value of outputCol.

setParams(self, \*, minDocFreq=0, inputCol=None, outputCol=None)Sets params for this IDF.

New in version 1.4.0.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

minDocFreq = Param(parent='undefined', name='minDocFreq', doc='minimum number of documents in which a term should appear for filtering')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

IDFModel

class pyspark.ml.feature.IDFModel(java_model=None)
Model fitted by IDF.

New in version 1.4.0.

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol()  Gets the value of inputCol or its default value.
getMinDocFreq()  Gets the value of minDocFreq or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getOutputCol()  Gets the value of outputCol or its default value.
getParam(paramName)  Gets a param by its name.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)  Sets a parameter in the embedded param map.
setInputCol(value)  Sets the value of inputCol.
setOutputCol(value)  Sets the value of outputCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.

Attributes

docFreq  Returns the document frequency.
idf  Returns the IDF vector.
inputCol
minDocFreq
numDocs  Returns number of documents evaluated to compute idf.
outputCol
params  Returns all params ordered by name.

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.


explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
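As a hedged illustration of that precedence (default < user-supplied < extra), using the IDF estimator rather than the fitted model because the estimator exposes a setter; like the other examples in this reference it relies on doctest ELLIPSIS:

>>> from pyspark.ml.feature import IDF
>>> idf = IDF()                      # minDocFreq falls back to its default value
>>> idf.setMinDocFreq(2)             # a user-supplied value overrides the default
IDF...
>>> extra = {idf.minDocFreq: 5}
>>> idf.extractParamMap(extra)[idf.minDocFreq]   # extra overrides the user-supplied value
5
>>> idf.extractParamMap()[idf.minDocFreq]        # without extra, the user-supplied value wins
2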

getInputCol(): Gets the value of inputCol or its default value.

getMinDocFreq(): Gets the value of minDocFreq or its default value.

New in version 1.4.0.

getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol(): Gets the value of outputCol or its default value.

getParam(paramName): Gets a param by its name.

hasDefault(param): Checks whether a param has a default value.

hasParam(paramName): Tests whether this instance contains a param with a given (string) name.

isDefined(param): Checks whether a param is explicitly set by user or has a default value.

isSet(param): Checks whether a param is explicitly set by user.

classmethod load(path): Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read(): Returns an MLReader instance for this class.

save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value): Sets a parameter in the embedded param map.

setInputCol(value): Sets the value of inputCol.

New in version 3.0.0.

setOutputCol(value): Sets the value of outputCol.


New in version 3.0.0.

transform(dataset, params=None): Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset
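A hedged sketch of the per-call override, where model stands for any fitted transformer (for instance the IDFModel from the sketch near the top of this entry) and df for its input DataFrame; the overridden column name is illustrative:

>>> default_out = model.transform(df)                                  # uses the embedded outputCol
>>> renamed_out = model.transform(df, {model.outputCol: "tfidf_alt"})  # overrides outputCol for this call only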

write(): Returns an MLWriter instance for this ML instance.

Attributes Documentation

docFreq: Returns the document frequency.

New in version 3.0.0.

idf: Returns the IDF vector.

New in version 2.0.0.

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

minDocFreq = Param(parent='undefined', name='minDocFreq', doc='minimum number of documents in which a term should appear for filtering')

numDocs: Returns number of documents evaluated to compute idf.

New in version 3.0.0.

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params: Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

Imputer

class pyspark.ml.feature.Imputer(*args, **kwargs)
    Imputation estimator for completing missing values, using either the mean or the median of the columns in which the missing values are located. The input columns should be of numeric type. Currently Imputer does not support categorical features and may create incorrect values for a categorical feature.

Note that the mean/median/mode value is computed after filtering out missing values. All Null values in the input columns are treated as missing, and so are also imputed. For computing the median, pyspark.sql.DataFrame.approxQuantile() is used with a relative error of 0.001.

New in version 2.2.0.


Examples

>>> df = spark.createDataFrame([(1.0, float("nan")), (2.0, float("nan")), (float("nan"), 3.0),
...                             (4.0, 4.0), (5.0, 5.0)], ["a", "b"])
>>> imputer = Imputer()
>>> imputer.setInputCols(["a", "b"])
Imputer...
>>> imputer.setOutputCols(["out_a", "out_b"])
Imputer...
>>> imputer.getRelativeError()
0.001
>>> model = imputer.fit(df)
>>> model.setInputCols(["a", "b"])
ImputerModel...
>>> model.getStrategy()
'mean'
>>> model.surrogateDF.show()
+---+---+
|  a|  b|
+---+---+
|3.0|4.0|
+---+---+
...
>>> model.transform(df).show()
+---+---+-----+-----+
|  a|  b|out_a|out_b|
+---+---+-----+-----+
|1.0|NaN|  1.0|  4.0|
|2.0|NaN|  2.0|  4.0|
|NaN|3.0|  3.0|  3.0|
...
>>> imputer.setStrategy("median").setMissingValue(1.0).fit(df).transform(df).show()
+---+---+-----+-----+
|  a|  b|out_a|out_b|
+---+---+-----+-----+
|1.0|NaN|  4.0|  NaN|
...
>>> df1 = spark.createDataFrame([(1.0,), (2.0,), (float("nan"),), (4.0,), (5.0,)], ["a"])
>>> imputer1 = Imputer(inputCol="a", outputCol="out_a")
>>> model1 = imputer1.fit(df1)
>>> model1.surrogateDF.show()
+---+
|  a|
+---+
|3.0|
+---+
...
>>> model1.transform(df1).show()
+---+-----+
|  a|out_a|
+---+-----+
|1.0|  1.0|
|2.0|  2.0|
|NaN|  3.0|
...
>>> imputer1.setStrategy("median").setMissingValue(1.0).fit(df1).transform(df1).show()
+---+-----+
|  a|out_a|
+---+-----+
|1.0|  4.0|
...
>>> df2 = spark.createDataFrame([(float("nan"),), (float("nan"),), (3.0,), (4.0,), (5.0,)],
...                             ["b"])
>>> imputer2 = Imputer(inputCol="b", outputCol="out_b")
>>> model2 = imputer2.fit(df2)
>>> model2.surrogateDF.show()
+---+
|  b|
+---+
|4.0|
+---+
...
>>> model2.transform(df2).show()
+---+-----+
|  b|out_b|
+---+-----+
|NaN|  4.0|
|NaN|  4.0|
|3.0|  3.0|
...
>>> imputer2.setStrategy("median").setMissingValue(1.0).fit(df2).transform(df2).show()
+---+-----+
|  b|out_b|
+---+-----+
|NaN|  NaN|
...
>>> imputerPath = temp_path + "/imputer"
>>> imputer.save(imputerPath)
>>> loadedImputer = Imputer.load(imputerPath)
>>> loadedImputer.getStrategy() == imputer.getStrategy()
True
>>> loadedImputer.getMissingValue()
1.0
>>> modelPath = temp_path + "/imputer-model"
>>> model.save(modelPath)
>>> loadedModel = ImputerModel.load(modelPath)
>>> loadedModel.transform(df).head().out_a == model.transform(df).head().out_a
True
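The note above states that Null values are treated as missing as well. The following hedged sketch, not part of the original example set, illustrates that case with the default mean strategy (an active spark session is assumed):

>>> df_null = spark.createDataFrame([(1.0,), (None,), (3.0,)], ["a"])
>>> model_null = Imputer(inputCol="a", outputCol="out_a").fit(df_null)
>>> [r.out_a for r in model_null.transform(df_null).collect()]   # the null row is imputed with the mean, (1.0 + 3.0) / 2
[1.0, 2.0, 3.0]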


Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getInputCol(): Gets the value of inputCol or its default value.
getInputCols(): Gets the value of inputCols or its default value.
getMissingValue(): Gets the value of missingValue or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getOutputCols(): Gets the value of outputCols or its default value.
getParam(paramName): Gets a param by its name.
getRelativeError(): Gets the value of relativeError or its default value.
getStrategy(): Gets the value of strategy or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setInputCols(value): Sets the value of inputCols.
setMissingValue(value): Sets the value of missingValue.
setOutputCol(value): Sets the value of outputCol.
setOutputCols(value): Sets the value of outputCols.
setParams(self, \*[, strategy, ...]): Sets params for this Imputer.
setRelativeError(value): Sets the value of relativeError.


setStrategy(value): Sets the value of strategy.
write(): Returns an MLWriter instance for this ML instance.

Attributes

inputCol
inputCols
missingValue
outputCol
outputCols
params: Returns all params ordered by name.
relativeError
strategy

Methods Documentation

clear(param): Clears a param from the param map if it has been explicitly set.

copy(extra=None): Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None): Fits a model to the input dataset with optional parameters.

New in version 1.3.0.


Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model), where the model was fit using paramMaps[index]. index values may not be sequential.
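A hedged sketch of how that iterator is typically consumed; it assumes the DataFrame df from the Examples above, while the fresh estimator imp and the two-value grid are illustrative:

>>> from pyspark.ml.tuning import ParamGridBuilder
>>> imp = Imputer(inputCols=["a", "b"], outputCols=["out_a", "out_b"])
>>> grid = ParamGridBuilder().addGrid(imp.strategy, ["mean", "median"]).build()
>>> models = [None] * len(grid)
>>> for index, fitted in imp.fitMultiple(df, grid):   # (index, model) pairs may arrive in any order
...     models[index] = fitted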

getInputCol(): Gets the value of inputCol or its default value.

getInputCols(): Gets the value of inputCols or its default value.

getMissingValue(): Gets the value of missingValue or its default value.

New in version 2.2.0.

getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol(): Gets the value of outputCol or its default value.

getOutputCols(): Gets the value of outputCols or its default value.

getParam(paramName): Gets a param by its name.

getRelativeError(): Gets the value of relativeError or its default value.

getStrategy(): Gets the value of strategy or its default value.

New in version 2.2.0.

hasDefault(param): Checks whether a param has a default value.

hasParam(paramName): Tests whether this instance contains a param with a given (string) name.


isDefined(param): Checks whether a param is explicitly set by user or has a default value.

isSet(param): Checks whether a param is explicitly set by user.

classmethod load(path): Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read(): Returns an MLReader instance for this class.

save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value): Sets a parameter in the embedded param map.

setInputCol(value): Sets the value of inputCol.

New in version 3.0.0.

setInputCols(value): Sets the value of inputCols.

New in version 2.2.0.

setMissingValue(value): Sets the value of missingValue.

New in version 2.2.0.

setOutputCol(value): Sets the value of outputCol.

New in version 3.0.0.

setOutputCols(value): Sets the value of outputCols.

New in version 2.2.0.

setParams(self, \*, strategy="mean", missingValue=float("nan"), inputCols=None, outputCols=None, inputCol=None, outputCol=None, relativeError=0.001): Sets params for this Imputer.

New in version 2.2.0.

setRelativeError(value): Sets the value of relativeError.

New in version 3.0.0.

setStrategy(value): Sets the value of strategy.

New in version 2.2.0.

write(): Returns an MLWriter instance for this ML instance.


Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')

missingValue = Param(parent='undefined', name='missingValue', doc='The placeholder for the missing values. All occurrences of missingValue will be imputed.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')

params: Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

relativeError = Param(parent='undefined', name='relativeError', doc='the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]')

strategy = Param(parent='undefined', name='strategy', doc='strategy for imputation. If mean, then replace missing values using the mean value of the feature. If median, then replace missing values using the median value of the feature. If mode, then replace missing using the most frequent value of the feature.')

ImputerModel

class pyspark.ml.feature.ImputerModel(java_model=None)
    Model fitted by Imputer.

New in version 2.2.0.

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getInputCols(): Gets the value of inputCols or its default value.
getMissingValue(): Gets the value of missingValue or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getOutputCols(): Gets the value of outputCols or its default value.
getParam(paramName): Gets a param by its name.
getRelativeError(): Gets the value of relativeError or its default value.


getStrategy(): Gets the value of strategy or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setInputCols(value): Sets the value of inputCols.
setOutputCol(value): Sets the value of outputCol.
setOutputCols(value): Sets the value of outputCols.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.

Attributes

inputCol
inputCols
missingValue
outputCol
outputCols
params: Returns all params ordered by name.
relativeError
strategy
surrogateDF: Returns a DataFrame containing inputCols and their corresponding surrogates, which are used to replace the missing values in the input DataFrame.

Methods Documentation

clear(param): Clears a param from the param map if it has been explicitly set.

copy(extra=None): Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCol(): Gets the value of inputCol or its default value.

getInputCols(): Gets the value of inputCols or its default value.

getMissingValue(): Gets the value of missingValue or its default value.

New in version 2.2.0.

getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol(): Gets the value of outputCol or its default value.

getOutputCols(): Gets the value of outputCols or its default value.

getParam(paramName): Gets a param by its name.

getRelativeError(): Gets the value of relativeError or its default value.

getStrategy(): Gets the value of strategy or its default value.

New in version 2.2.0.

hasDefault(param): Checks whether a param has a default value.

hasParam(paramName): Tests whether this instance contains a param with a given (string) name.

isDefined(param): Checks whether a param is explicitly set by user or has a default value.

isSet(param): Checks whether a param is explicitly set by user.


classmethod load(path): Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read(): Returns an MLReader instance for this class.

save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value): Sets a parameter in the embedded param map.

setInputCol(value): Sets the value of inputCol.

New in version 3.0.0.

setInputCols(value): Sets the value of inputCols.

New in version 3.0.0.

setOutputCol(value): Sets the value of outputCol.

New in version 3.0.0.

setOutputCols(value): Sets the value of outputCols.

New in version 3.0.0.

transform(dataset, params=None): Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write(): Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')

missingValue = Param(parent='undefined', name='missingValue', doc='The placeholder for the missing values. All occurrences of missingValue will be imputed.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')

params: Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.


relativeError = Param(parent='undefined', name='relativeError', doc='the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]')

strategy = Param(parent='undefined', name='strategy', doc='strategy for imputation. If mean, then replace missing values using the mean value of the feature. If median, then replace missing values using the median value of the feature. If mode, then replace missing using the most frequent value of the feature.')

surrogateDF: Returns a DataFrame containing inputCols and their corresponding surrogates, which are used to replace the missing values in the input DataFrame.

New in version 2.2.0.

IndexToString

class pyspark.ml.feature.IndexToString(*args, **kwargs)
    A pyspark.ml.base.Transformer that maps a column of indices back to a new column of corresponding string values. The index-string mapping is either from the ML attributes of the input column, or from user-supplied labels (which take precedence over ML attributes).

New in version 1.6.0.

See also:

StringIndexer for converting categorical values into category indices
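IndexToString has no example block in this reference; the following hedged sketch pairs it with StringIndexer to round-trip label values. An active spark session is assumed and the column names are illustrative:

>>> from pyspark.ml.feature import StringIndexer, IndexToString
>>> df = spark.createDataFrame([(0, "a"), (1, "b"), (2, "c"), (3, "a")], ["id", "label"])
>>> indexerModel = StringIndexer(inputCol="label", outputCol="labelIndex").fit(df)
>>> indexed = indexerModel.transform(df)
>>> converter = IndexToString(inputCol="labelIndex", outputCol="originalLabel",
...                           labels=indexerModel.labels)
>>> restored = converter.transform(indexed)
>>> [r.originalLabel for r in restored.collect()] == [r.label for r in df.collect()]
True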

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getLabels(): Gets the value of labels or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.


load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setLabels(value): Sets the value of labels.
setOutputCol(value): Sets the value of outputCol.
setParams(self, \*[, inputCol, outputCol, ...]): Sets params for this IndexToString.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.

Attributes

inputCol
labels
outputCol
params: Returns all params ordered by name.

Methods Documentation

clear(param): Clears a param from the param map if it has been explicitly set.

copy(extra=None): Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values


Returns

dict merged param map

getInputCol(): Gets the value of inputCol or its default value.

getLabels(): Gets the value of labels or its default value.

New in version 1.6.0.

getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol(): Gets the value of outputCol or its default value.

getParam(paramName): Gets a param by its name.

hasDefault(param): Checks whether a param has a default value.

hasParam(paramName): Tests whether this instance contains a param with a given (string) name.

isDefined(param): Checks whether a param is explicitly set by user or has a default value.

isSet(param): Checks whether a param is explicitly set by user.

classmethod load(path): Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read(): Returns an MLReader instance for this class.

save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value): Sets a parameter in the embedded param map.

setInputCol(value): Sets the value of inputCol.

setLabels(value): Sets the value of labels.

New in version 1.6.0.

setOutputCol(value): Sets the value of outputCol.

setParams(self, \*, inputCol=None, outputCol=None, labels=None): Sets params for this IndexToString.

New in version 1.6.0.

transform(dataset, params=None): Transforms the input dataset with optional parameters.


New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write(): Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

labels = Param(parent='undefined', name='labels', doc='Optional array of labels specifying index-string mapping. If not provided or if empty, then metadata from inputCol is used instead.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params: Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

Interaction

class pyspark.ml.feature.Interaction(*args, **kwargs)
    Implements the feature interaction transform. This transformer takes in Double and Vector type columns and outputs a flattened vector of their feature interactions. To handle interaction, we first one-hot encode any nominal features. Then, a vector of the feature cross-products is produced.

    For example, given the input feature values Double(2) and Vector(3, 4), the output would be Vector(6, 8) if all input features were numeric. If the first feature was instead nominal with four categories, the output would then be Vector(0, 0, 0, 0, 3, 4, 0, 0).

New in version 3.0.0.

Examples

>>> df = spark.createDataFrame([(0.0, 1.0), (2.0, 3.0)], ["a", "b"])
>>> interaction = Interaction()
>>> interaction.setInputCols(["a", "b"])
Interaction...
>>> interaction.setOutputCol("ab")
Interaction...
>>> interaction.transform(df).show()
+---+---+-----+
|  a|  b|   ab|
+---+---+-----+
|0.0|1.0|[0.0]|
|2.0|3.0|[6.0]|
+---+---+-----+
...
>>> interactionPath = temp_path + "/interaction"
>>> interaction.save(interactionPath)
>>> loadedInteraction = Interaction.load(interactionPath)
>>> loadedInteraction.transform(df).head().ab == interaction.transform(df).head().ab
True
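The class description also covers Vector-typed inputs; the following hedged sketch, not part of the original example set, shows that numeric case (column names are illustrative):

>>> from pyspark.ml.linalg import Vectors
>>> dfv = spark.createDataFrame([(2.0, Vectors.dense([3.0, 4.0]))], ["x", "v"])
>>> inter = Interaction(inputCols=["x", "v"], outputCol="interacted")
>>> list(inter.transform(dfv).head().interacted)   # Double(2) x Vector(3, 4) -> Vector(6, 8)
[6.0, 8.0]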

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCols(): Gets the value of inputCols or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCols(value): Sets the value of inputCols.
setOutputCol(value): Sets the value of outputCol.
setParams(self, \*[, inputCols, outputCol]): Sets params for this Interaction.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.


Attributes

inputCols
outputCol
params: Returns all params ordered by name.

Methods Documentation

clear(param): Clears a param from the param map if it has been explicitly set.

copy(extra=None): Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCols(): Gets the value of inputCols or its default value.

getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol(): Gets the value of outputCol or its default value.

getParam(paramName): Gets a param by its name.

hasDefault(param): Checks whether a param has a default value.


hasParam(paramName): Tests whether this instance contains a param with a given (string) name.

isDefined(param): Checks whether a param is explicitly set by user or has a default value.

isSet(param): Checks whether a param is explicitly set by user.

classmethod load(path): Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read(): Returns an MLReader instance for this class.

save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value): Sets a parameter in the embedded param map.

setInputCols(value): Sets the value of inputCols.

New in version 3.0.0.

setOutputCol(value): Sets the value of outputCol.

New in version 3.0.0.

setParams(self, \*, inputCols=None, outputCol=None): Sets params for this Interaction.

New in version 3.0.0.

transform(dataset, params=None): Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write(): Returns an MLWriter instance for this ML instance.


Attributes Documentation

inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params: Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

MaxAbsScaler

class pyspark.ml.feature.MaxAbsScaler(*args, **kwargs)
    Rescale each feature individually to range [-1, 1] by dividing through the largest maximum absolute value in each feature. It does not shift/center the data, and thus does not destroy any sparsity.

New in version 2.0.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([1.0]),), (Vectors.dense([2.0]),)], ["a"])
>>> maScaler = MaxAbsScaler(outputCol="scaled")
>>> maScaler.setInputCol("a")
MaxAbsScaler...
>>> model = maScaler.fit(df)
>>> model.setOutputCol("scaledOutput")
MaxAbsScalerModel...
>>> model.transform(df).show()
+-----+------------+
|    a|scaledOutput|
+-----+------------+
|[1.0]|       [0.5]|
|[2.0]|       [1.0]|
+-----+------------+
...
>>> scalerPath = temp_path + "/max-abs-scaler"
>>> maScaler.save(scalerPath)
>>> loadedMAScaler = MaxAbsScaler.load(scalerPath)
>>> loadedMAScaler.getInputCol() == maScaler.getInputCol()
True
>>> loadedMAScaler.getOutputCol() == maScaler.getOutputCol()
True
>>> modelPath = temp_path + "/max-abs-scaler-model"
>>> model.save(modelPath)
>>> loadedModel = MaxAbsScalerModel.load(modelPath)
>>> loadedModel.maxAbs == model.maxAbs
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True


Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getInputCol(): Gets the value of inputCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
setParams(self, \*[, inputCol, outputCol]): Sets params for this MaxAbsScaler.
write(): Returns an MLWriter instance for this ML instance.


Attributes

inputCol
outputCol
params: Returns all params ordered by name.

Methods Documentation

clear(param): Clears a param from the param map if it has been explicitly set.

copy(extra=None): Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None): Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)


fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model), where the model was fit using paramMaps[index]. index values may not be sequential.

getInputCol(): Gets the value of inputCol or its default value.

getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol(): Gets the value of outputCol or its default value.

getParam(paramName): Gets a param by its name.

hasDefault(param): Checks whether a param has a default value.

hasParam(paramName): Tests whether this instance contains a param with a given (string) name.

isDefined(param): Checks whether a param is explicitly set by user or has a default value.

isSet(param): Checks whether a param is explicitly set by user.

classmethod load(path): Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read(): Returns an MLReader instance for this class.

save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value): Sets a parameter in the embedded param map.

setInputCol(value): Sets the value of inputCol.

setOutputCol(value): Sets the value of outputCol.

setParams(self, \*, inputCol=None, outputCol=None): Sets params for this MaxAbsScaler.

New in version 2.0.0.


write(): Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params: Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

MaxAbsScalerModel

class pyspark.ml.feature.MaxAbsScalerModel(java_model=None)
    Model fitted by MaxAbsScaler.

New in version 2.0.0.

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.


save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.

Attributes

inputCol
maxAbs: Max Abs vector.
outputCol
params: Returns all params ordered by name.

Methods Documentation

clear(param): Clears a param from the param map if it has been explicitly set.

copy(extra=None): Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCol(): Gets the value of inputCol or its default value.


getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol(): Gets the value of outputCol or its default value.

getParam(paramName): Gets a param by its name.

hasDefault(param): Checks whether a param has a default value.

hasParam(paramName): Tests whether this instance contains a param with a given (string) name.

isDefined(param): Checks whether a param is explicitly set by user or has a default value.

isSet(param): Checks whether a param is explicitly set by user.

classmethod load(path): Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read(): Returns an MLReader instance for this class.

save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value): Sets a parameter in the embedded param map.

setInputCol(value): Sets the value of inputCol.

New in version 3.0.0.

setOutputCol(value): Sets the value of outputCol.

New in version 3.0.0.

transform(dataset, params=None): Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write(): Returns an MLWriter instance for this ML instance.


Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

maxAbs: Max Abs vector.

New in version 2.0.0.

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params: Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

MinHashLSH

class pyspark.ml.feature.MinHashLSH(*args, **kwargs)
    LSH class for Jaccard distance. The input can be dense or sparse vectors, but it is more efficient if it is sparse. For example, Vectors.sparse(10, [(2, 1.0), (3, 1.0), (5, 1.0)]) means there are 10 elements in the space. This set contains elements 2, 3, and 5. Also, any input vector must have at least 1 non-zero index, and all non-zero values are treated as binary "1" values.

New in version 2.2.0.

Notes

See Wikipedia on MinHash

Examples

>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.sql.functions import col
>>> data = [(0, Vectors.sparse(6, [0, 1, 2], [1.0, 1.0, 1.0]),),
...         (1, Vectors.sparse(6, [2, 3, 4], [1.0, 1.0, 1.0]),),
...         (2, Vectors.sparse(6, [0, 2, 4], [1.0, 1.0, 1.0]),)]
>>> df = spark.createDataFrame(data, ["id", "features"])
>>> mh = MinHashLSH()
>>> mh.setInputCol("features")
MinHashLSH...
>>> mh.setOutputCol("hashes")
MinHashLSH...
>>> mh.setSeed(12345)
MinHashLSH...
>>> model = mh.fit(df)
>>> model.setInputCol("features")
MinHashLSHModel...
>>> model.transform(df).head()
Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668...
>>> data2 = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),),
...          (4, Vectors.sparse(6, [2, 3, 5], [1.0, 1.0, 1.0]),),
...          (5, Vectors.sparse(6, [1, 2, 4], [1.0, 1.0, 1.0]),)]
>>> df2 = spark.createDataFrame(data2, ["id", "features"])
>>> key = Vectors.sparse(6, [1, 2], [1.0, 1.0])
>>> model.approxNearestNeighbors(df2, key, 1).collect()
[Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([6179668...
>>> model.approxSimilarityJoin(df, df2, 0.6, distCol="JaccardDistance").select(
...     col("datasetA.id").alias("idA"),
...     col("datasetB.id").alias("idB"),
...     col("JaccardDistance")).show()
+---+---+---------------+
|idA|idB|JaccardDistance|
+---+---+---------------+
|  0|  5|            0.5|
|  1|  4|            0.5|
+---+---+---------------+
...
>>> mhPath = temp_path + "/mh"
>>> mh.save(mhPath)
>>> mh2 = MinHashLSH.load(mhPath)
>>> mh2.getOutputCol() == mh.getOutputCol()
True
>>> modelPath = temp_path + "/mh-model"
>>> model.save(modelPath)
>>> model2 = MinHashLSHModel.load(modelPath)
>>> model.transform(df).head().hashes == model2.transform(df).head().hashes
True

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getInputCol(): Gets the value of inputCol or its default value.
getNumHashTables(): Gets the value of numHashTables or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.


getParam(paramName): Gets a param by its name.
getSeed(): Gets the value of seed or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setNumHashTables(value): Sets the value of numHashTables.
setOutputCol(value): Sets the value of outputCol.
setParams(self, \*[, inputCol, outputCol, ...]): Sets params for this MinHashLSH.
setSeed(value): Sets the value of seed.
write(): Returns an MLWriter instance for this ML instance.

Attributes

inputCol
numHashTables
outputCol
params: Returns all params ordered by name.
seed

Methods Documentation

clear(param): Clears a param from the param map if it has been explicitly set.

copy(extra=None): Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.


extractParamMap(extra=None): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None): Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model), where the model was fit using paramMaps[index]. index values may not be sequential.

getInputCol(): Gets the value of inputCol or its default value.

getNumHashTables(): Gets the value of numHashTables or its default value.

getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol(): Gets the value of outputCol or its default value.

getParam(paramName): Gets a param by its name.

getSeed(): Gets the value of seed or its default value.


hasDefault(param): Checks whether a param has a default value.

hasParam(paramName): Tests whether this instance contains a param with a given (string) name.

isDefined(param): Checks whether a param is explicitly set by user or has a default value.

isSet(param): Checks whether a param is explicitly set by user.

classmethod load(path): Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read(): Returns an MLReader instance for this class.

save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value): Sets a parameter in the embedded param map.

setInputCol(value): Sets the value of inputCol.

setNumHashTables(value): Sets the value of numHashTables.

setOutputCol(value): Sets the value of outputCol.

setParams(self, \*, inputCol=None, outputCol=None, seed=None, numHashTables=1): Sets params for this MinHashLSH.

New in version 2.2.0.

setSeed(value): Sets the value of seed.

write(): Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

numHashTables = Param(parent='undefined', name='numHashTables', doc='number of hash tables, where increasing number of hash tables lowers the false negative rate, and decreasing it improves the running performance.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params: Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

seed = Param(parent='undefined', name='seed', doc='random seed.')


MinHashLSHModel

class pyspark.ml.feature.MinHashLSHModel(java_model=None)
    Model produced by MinHashLSH, where multiple hash functions are stored. Each hash function is picked from the following family of hash functions, where a_i and b_i are randomly chosen integers less than prime: h_i(x) = ((x * a_i + b_i) mod prime). This hash family is approximately min-wise independent according to the reference.
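For readability, the same hash family can also be written in LaTeX form (no new content, just a cleaner rendering of the formula above):

h_i(x) = \left( (x \cdot a_i + b_i) \bmod \mathrm{prime} \right)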

New in version 2.2.0.

Notes

See Tom Bohman, Colin Cooper, and Alan Frieze. "Min-wise independent linear permutations." Electronic Journal of Combinatorics 7 (2000): R26.

Methods

approxNearestNeighbors(dataset, key, ...[, ...]): Given a large dataset and an item, approximately find at most k items which have the closest distance to the item.
approxSimilarityJoin(datasetA, datasetB, ...): Join two datasets to approximately find all pairs of rows whose distance are smaller than the threshold.
clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getNumHashTables(): Gets the value of numHashTables or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).

3.3. ML 419

Page 424: pyspark Documentation

pyspark Documentation, Release master

Table 125 – continued from previous pageread() Returns an MLReader instance for this class.save(path) Save this ML instance to the given path, a shortcut of

‘write().save(path)’.set(param, value) Sets a parameter in the embedded param map.setInputCol(value) Sets the value of inputCol.setOutputCol(value) Sets the value of outputCol.transform(dataset[, params]) Transforms the input dataset with optional parame-

ters.write() Returns an MLWriter instance for this ML instance.

Attributes

inputColnumHashTablesoutputColparams Returns all params ordered by name.

Methods Documentation

approxNearestNeighbors(dataset, key, numNearestNeighbors, distCol='distCol') Given a large dataset and an item, approximately find at most k items which have the closest distance to the item. If the outputCol is missing, the method will transform the data; if the outputCol exists, it will use that. This allows caching of the transformed data when necessary.

Parameters

dataset [pyspark.sql.DataFrame] The dataset to search for nearest neighbors of the key.

key [pyspark.ml.linalg.Vector] Feature vector representing the item to search for.

numNearestNeighbors [int] The maximum number of nearest neighbors.

distCol [str] Output column for storing the distance between each result row and the key. Use “distCol” as the default value if it’s not specified.

Returns

pyspark.sql.DataFrame A dataset containing at most k items closest to the key. A column “distCol” is added to show the distance between each row and the key.

Notes

This method is experimental and will likely change behavior in the next release.
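A small hedged sketch of the call, reusing the illustrative `model`, `df`, and `key` objects from the sketch above and overriding the distance column name:

# `model`, `df`, and `key` are the illustrative objects from the earlier sketch.
neighbors = model.approxNearestNeighbors(df, key, numNearestNeighbors=3, distCol="jaccardDist")
neighbors.show()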

approxSimilarityJoin(datasetA, datasetB, threshold, distCol='distCol') Join two datasets to approximately find all pairs of rows whose distance is smaller than the threshold. If the outputCol is missing, the method will transform the data; if the outputCol exists, it will use that. This allows caching of the transformed data when necessary.

Parameters

datasetA [pyspark.sql.DataFrame] One of the datasets to join.

datasetB [pyspark.sql.DataFrame] Another dataset to join.

threshold [float] The threshold for the distance of row pairs.

distCol [str, optional] Output column for storing the distance between each pair of rows. Use “distCol” as the default value if it’s not specified.

Returns

pyspark.sql.DataFrame A joined dataset containing pairs of rows. The original rows are in columns “datasetA” and “datasetB”, and a column “distCol” is added to show the distance between each pair.

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCol()Gets the value of inputCol or its default value.

getNumHashTables()Gets the value of numHashTables or its default value.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.


hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setInputCol(value)Sets the value of inputCol.

setOutputCol(value)Sets the value of outputCol.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

numHashTables = Param(parent='undefined', name='numHashTables', doc='number of hash tables, where increasing number of hash tables lowers the false negative rate, and decreasing it improves the running performance.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.


MinMaxScaler

class pyspark.ml.feature.MinMaxScaler(*args, **kwargs)
Rescale each feature individually to a common range [min, max] linearly using column summary statistics, which is also known as min-max normalization or Rescaling. The rescaled value for feature E is calculated as,

Rescaled(e_i) = (e_i - E_min) / (E_max - E_min) * (max - min) + min

For the case E_max == E_min, Rescaled(e_i) = 0.5 * (max + min)

New in version 1.6.0.

Notes

Since zero values will probably be transformed to non-zero values, output of the transformer will be DenseVector even for sparse input.
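As a plain-Python sanity check of the formula above (not part of the PySpark API), using the statistics that fitting would produce for the example below, E_min = 0.0 and E_max = 2.0, with the default target range [0.0, 1.0]:

# Plain-Python check of the rescaling formula; mirrors the doctest that follows.
E_min, E_max = 0.0, 2.0      # column summary statistics obtained during fitting
out_min, out_max = 0.0, 1.0  # target range [min, max]

def rescale(e_i):
    if E_max == E_min:
        return 0.5 * (out_max + out_min)
    return (e_i - E_min) / (E_max - E_min) * (out_max - out_min) + out_min

assert rescale(0.0) == 0.0 and rescale(2.0) == 1.0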

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"])
>>> mmScaler = MinMaxScaler(outputCol="scaled")
>>> mmScaler.setInputCol("a")
MinMaxScaler...
>>> model = mmScaler.fit(df)
>>> model.setOutputCol("scaledOutput")
MinMaxScalerModel...
>>> model.originalMin
DenseVector([0.0])
>>> model.originalMax
DenseVector([2.0])
>>> model.transform(df).show()
+-----+------------+
|    a|scaledOutput|
+-----+------------+
|[0.0]|       [0.0]|
|[2.0]|       [1.0]|
+-----+------------+
...
>>> minMaxScalerPath = temp_path + "/min-max-scaler"
>>> mmScaler.save(minMaxScalerPath)
>>> loadedMMScaler = MinMaxScaler.load(minMaxScalerPath)
>>> loadedMMScaler.getMin() == mmScaler.getMin()
True
>>> loadedMMScaler.getMax() == mmScaler.getMax()
True
>>> modelPath = temp_path + "/min-max-scaler-model"
>>> model.save(modelPath)
>>> loadedModel = MinMaxScalerModel.load(modelPath)
>>> loadedModel.originalMin == model.originalMin
True
>>> loadedModel.originalMax == model.originalMax
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True


Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getInputCol() Gets the value of inputCol or its default value.
getMax() Gets the value of max or its default value.
getMin() Gets the value of min or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setMax(value) Sets the value of max.
setMin(value) Sets the value of min.
setOutputCol(value) Sets the value of outputCol.
setParams(self, \*[, min, max, inputCol, ...]) Sets params for this MinMaxScaler.
write() Returns an MLWriter instance for this ML instance.

Attributes

inputCol
max
min
outputCol
params Returns all params ordered by name.

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first callsParams.copy and then make a copy of the companion Java pipeline component with extra params. So boththe Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns


Transformer or a list of Transformer fitted model(s)
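A hedged sketch of the list-of-param-maps form described above, reusing the DataFrame `df` with vector column "a" from the MinMaxScaler example (names are illustrative):

# One model is fitted per param map; the result is a list in the same order.
scaler = MinMaxScaler(inputCol="a", outputCol="scaled")
models = scaler.fit(df, params=[{scaler.max: 1.0}, {scaler.max: 10.0}])
print(len(models))  # 2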

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getInputCol()Gets the value of inputCol or its default value.

getMax()Gets the value of max or its default value.

New in version 1.6.0.

getMin()Gets the value of min or its default value.

New in version 1.6.0.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.


setInputCol(value)Sets the value of inputCol.

setMax(value)Sets the value of max.

New in version 1.6.0.

setMin(value)Sets the value of min.

New in version 1.6.0.

setOutputCol(value)Sets the value of outputCol.

setParams(self, \*, min=0.0, max=1.0, inputCol=None, outputCol=None)Sets params for this MinMaxScaler.

New in version 1.6.0.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

max = Param(parent='undefined', name='max', doc='Upper bound of the output feature range')

min = Param(parent='undefined', name='min', doc='Lower bound of the output feature range')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

MinMaxScalerModel

class pyspark.ml.feature.MinMaxScalerModel(java_model=None)
Model fitted by MinMaxScaler.

New in version 1.6.0.

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol() Gets the value of inputCol or its default value.
getMax() Gets the value of max or its default value.
getMin() Gets the value of min or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setMax(value) Sets the value of max.
setMin(value) Sets the value of min.
setOutputCol(value) Sets the value of outputCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

inputCol
max
min
originalMax Max value for each original column during fitting.
originalMin Min value for each original column during fitting.
outputCol
params Returns all params ordered by name.


Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first callsParams.copy and then make a copy of the companion Java pipeline component with extra params. So boththe Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCol()Gets the value of inputCol or its default value.

getMax()Gets the value of max or its default value.

New in version 1.6.0.

getMin()Gets the value of min or its default value.

New in version 1.6.0.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.


hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setInputCol(value)Sets the value of inputCol.

New in version 3.0.0.

setMax(value)Sets the value of max.

New in version 3.0.0.

setMin(value)Sets the value of min.

New in version 3.0.0.

setOutputCol(value)Sets the value of outputCol.

New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.


Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

max = Param(parent='undefined', name='max', doc='Upper bound of the output feature range')

min = Param(parent='undefined', name='min', doc='Lower bound of the output feature range')

originalMax Max value for each original column during fitting.

New in version 2.0.0.

originalMin Min value for each original column during fitting.

New in version 2.0.0.

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

NGram

class pyspark.ml.feature.NGram(*args, **kwargs)
A feature transformer that converts the input array of strings into an array of n-grams. Null values in the input array are ignored. It returns an array of n-grams where each n-gram is represented by a space-separated string of words. When the input is empty, an empty array is returned. When the input array length is less than n (number of elements per n-gram), no n-grams are returned.

New in version 1.5.0.

Examples

>>> df = spark.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])])
>>> ngram = NGram(n=2)
>>> ngram.setInputCol("inputTokens")
NGram...
>>> ngram.setOutputCol("nGrams")
NGram...
>>> ngram.transform(df).head()
Row(inputTokens=['a', 'b', 'c', 'd', 'e'], nGrams=['a b', 'b c', 'c d', 'd e'])
>>> # Change n-gram length
>>> ngram.setParams(n=4).transform(df).head()
Row(inputTokens=['a', 'b', 'c', 'd', 'e'], nGrams=['a b c d', 'b c d e'])
>>> # Temporarily modify output column.
>>> ngram.transform(df, {ngram.outputCol: "output"}).head()
Row(inputTokens=['a', 'b', 'c', 'd', 'e'], output=['a b c d', 'b c d e'])
>>> ngram.transform(df).head()
Row(inputTokens=['a', 'b', 'c', 'd', 'e'], nGrams=['a b c d', 'b c d e'])
>>> # Must use keyword arguments to specify params.
>>> ngram.setParams("text")
Traceback (most recent call last):
    ...
TypeError: Method setParams forces keyword arguments.
>>> ngramPath = temp_path + "/ngram"
>>> ngram.save(ngramPath)
>>> loadedNGram = NGram.load(ngramPath)
>>> loadedNGram.getN() == ngram.getN()
True
>>> loadedNGram.transform(df).take(1) == ngram.transform(df).take(1)
True

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol() Gets the value of inputCol or its default value.
getN() Gets the value of n or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setN(value) Sets the value of n.
setOutputCol(value) Sets the value of outputCol.
setParams(self, \*[, n, inputCol, outputCol]) Sets params for this NGram.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

inputCol
n
outputCol
params Returns all params ordered by name.

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first callsParams.copy and then make a copy of the companion Java pipeline component with extra params. So boththe Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCol()Gets the value of inputCol or its default value.

getN()Gets the value of n or its default value.

New in version 1.5.0.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()Gets the value of outputCol or its default value.


getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setInputCol(value)Sets the value of inputCol.

setN(value)Sets the value of n.

New in version 1.5.0.

setOutputCol(value)Sets the value of outputCol.

setParams(self, \*, n=2, inputCol=None, outputCol=None)Sets params for this NGram.

New in version 1.5.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.


Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

n = Param(parent='undefined', name='n', doc='number of elements per n-gram (>=1)')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

Normalizer

class pyspark.ml.feature.Normalizer(*args, **kwargs)

Normalize a vector to have unit norm using the given p-norm.

New in version 1.4.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> svec = Vectors.sparse(4, {1: 4.0, 3: 3.0})
>>> df = spark.createDataFrame([(Vectors.dense([3.0, -4.0]), svec)], ["dense", "sparse"])
>>> normalizer = Normalizer(p=2.0)
>>> normalizer.setInputCol("dense")
Normalizer...
>>> normalizer.setOutputCol("features")
Normalizer...
>>> normalizer.transform(df).head().features
DenseVector([0.6, -0.8])
>>> normalizer.setParams(inputCol="sparse", outputCol="freqs").transform(df).head().freqs
SparseVector(4, {1: 0.8, 3: 0.6})
>>> params = {normalizer.p: 1.0, normalizer.inputCol: "dense", normalizer.outputCol: "vector"}
>>> normalizer.transform(df, params).head().vector
DenseVector([0.4286, -0.5714])
>>> normalizerPath = temp_path + "/normalizer"
>>> normalizer.save(normalizerPath)
>>> loadedNormalizer = Normalizer.load(normalizerPath)
>>> loadedNormalizer.getP() == normalizer.getP()
True
>>> loadedNormalizer.transform(df).take(1) == normalizer.transform(df).take(1)
True


Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol() Gets the value of inputCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getP() Gets the value of p or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setOutputCol(value) Sets the value of outputCol.
setP(value) Sets the value of p.
setParams(self, \*[, p, inputCol, outputCol]) Sets params for this Normalizer.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

inputCol
outputCol
p
params Returns all params ordered by name.

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first callsParams.copy and then make a copy of the companion Java pipeline component with extra params. So boththe Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCol()Gets the value of inputCol or its default value.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()Gets the value of outputCol or its default value.

getP()Gets the value of p or its default value.

New in version 1.4.0.


getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setInputCol(value)Sets the value of inputCol.

setOutputCol(value)Sets the value of outputCol.

setP(value)Sets the value of p.

New in version 1.4.0.

setParams(self, \*, p=2.0, inputCol=None, outputCol=None)Sets params for this Normalizer.

New in version 1.4.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.


Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

p = Param(parent='undefined', name='p', doc='the p norm value.')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

OneHotEncoder

class pyspark.ml.feature.OneHotEncoder(*args, **kwargs)
A one-hot encoder that maps a column of category indices to a column of binary vectors, with at most a single one-value per row that indicates the input category index. For example with 5 categories, an input value of 2.0 would map to an output vector of [0.0, 0.0, 1.0, 0.0]. The last category is not included by default (configurable via dropLast), because it makes the vector entries sum up to one, and hence linearly dependent. So an input value of 4.0 maps to [0.0, 0.0, 0.0, 0.0].

When handleInvalid is configured to ‘keep’, an extra “category” indicating invalid values is added as the last category. So when dropLast is true, invalid values are encoded as an all-zeros vector.

New in version 2.3.0.

See also:

StringIndexer for converting categorical values into category indices

Notes

This is different from scikit-learn’s OneHotEncoder, which keeps all categories. The output vectors are sparse.

When encoding multi-column by using inputCols and outputCols params, input/output cols come in pairs, specified by the order in the arrays, and each pair is treated independently (a hedged multi-column sketch follows the example below).

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"])
>>> ohe = OneHotEncoder()
>>> ohe.setInputCols(["input"])
OneHotEncoder...
>>> ohe.setOutputCols(["output"])
OneHotEncoder...
>>> model = ohe.fit(df)
>>> model.setOutputCols(["output"])
OneHotEncoderModel...
>>> model.getHandleInvalid()
'error'
>>> model.transform(df).head().output
SparseVector(2, {0: 1.0})
>>> single_col_ohe = OneHotEncoder(inputCol="input", outputCol="output")
>>> single_col_model = single_col_ohe.fit(df)
>>> single_col_model.transform(df).head().output
SparseVector(2, {0: 1.0})
>>> ohePath = temp_path + "/ohe"
>>> ohe.save(ohePath)
>>> loadedOHE = OneHotEncoder.load(ohePath)
>>> loadedOHE.getInputCols() == ohe.getInputCols()
True
>>> modelPath = temp_path + "/ohe-model"
>>> model.save(modelPath)
>>> loadedModel = OneHotEncoderModel.load(modelPath)
>>> loadedModel.categorySizes == model.categorySizes
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
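Following up on the note above that inputCols and outputCols are paired by position, here is a hedged multi-column sketch; the column names and data are illustrative and a SparkSession `spark` is assumed:

# Each input/output pair is encoded independently: "catA" pairs with "vecA", "catB" with "vecB".
df2 = spark.createDataFrame([(0.0, 1.0), (1.0, 0.0), (2.0, 1.0)], ["catA", "catB"])
ohe2 = OneHotEncoder(inputCols=["catA", "catB"], outputCols=["vecA", "vecB"])
model2 = ohe2.fit(df2)
model2.transform(df2).show(truncate=False)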

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getDropLast() Gets the value of dropLast or its default value.
getHandleInvalid() Gets the value of handleInvalid or its default value.
getInputCol() Gets the value of inputCol or its default value.
getInputCols() Gets the value of inputCols or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getOutputCols() Gets the value of outputCols or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setDropLast(value) Sets the value of dropLast.
setHandleInvalid(value) Sets the value of handleInvalid.
setInputCol(value) Sets the value of inputCol.
setInputCols(value) Sets the value of inputCols.
setOutputCol(value) Sets the value of outputCol.
setOutputCols(value) Sets the value of outputCols.
setParams(self, \*[, inputCols, outputCols, ...]) Sets params for this OneHotEncoder.
write() Returns an MLWriter instance for this ML instance.

Attributes

dropLast
handleInvalid
inputCol
inputCols
outputCol
outputCols
params Returns all params ordered by name.

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first callsParams.copy and then make a copy of the companion Java pipeline component with extra params. So boththe Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getDropLast()Gets the value of dropLast or its default value.

New in version 2.3.0.

getHandleInvalid()Gets the value of handleInvalid or its default value.

getInputCol()Gets the value of inputCol or its default value.

getInputCols()Gets the value of inputCols or its default value.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()Gets the value of outputCol or its default value.


getOutputCols()Gets the value of outputCols or its default value.

getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setDropLast(value)Sets the value of dropLast.

New in version 2.3.0.

setHandleInvalid(value)Sets the value of handleInvalid.

New in version 3.0.0.

setInputCol(value)Sets the value of inputCol.

New in version 3.0.0.

setInputCols(value)Sets the value of inputCols.

New in version 3.0.0.

setOutputCol(value)Sets the value of outputCol.

New in version 3.0.0.

setOutputCols(value)Sets the value of outputCols.

New in version 3.0.0.

setParams(self, \*, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, inputCol=None, outputCol=None)

Sets params for this OneHotEncoder.

New in version 2.3.0.


write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

dropLast = Param(parent='undefined', name='dropLast', doc='whether to drop the last category')

handleInvalid = Param(parent='undefined', name='handleInvalid', doc="How to handle invalid data during transform(). Options are 'keep' (invalid data presented as an extra categorical feature) or error (throw an error). Note that this Param is only used during transform; during fitting, invalid data will result in an error.")

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

OneHotEncoderModel

class pyspark.ml.feature.OneHotEncoderModel(java_model=None)
Model fitted by OneHotEncoder.

New in version 2.3.0.

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getDropLast() Gets the value of dropLast or its default value.
getHandleInvalid() Gets the value of handleInvalid or its default value.
getInputCol() Gets the value of inputCol or its default value.
getInputCols() Gets the value of inputCols or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getOutputCols() Gets the value of outputCols or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setDropLast(value) Sets the value of dropLast.
setHandleInvalid(value) Sets the value of handleInvalid.
setInputCol(value) Sets the value of inputCol.
setInputCols(value) Sets the value of inputCols.
setOutputCol(value) Sets the value of outputCol.
setOutputCols(value) Sets the value of outputCols.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

categorySizes Original number of categories for each feature being encoded.
dropLast
handleInvalid
inputCol
inputCols
outputCol
outputCols
params Returns all params ordered by name.

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first callsParams.copy and then make a copy of the companion Java pipeline component with extra params. So boththe Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance


explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getDropLast()Gets the value of dropLast or its default value.

New in version 2.3.0.

getHandleInvalid()Gets the value of handleInvalid or its default value.

getInputCol()Gets the value of inputCol or its default value.

getInputCols()Gets the value of inputCols or its default value.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()Gets the value of outputCol or its default value.

getOutputCols()Gets the value of outputCols or its default value.

getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.


save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setDropLast(value)Sets the value of dropLast.

New in version 3.0.0.

setHandleInvalid(value)Sets the value of handleInvalid.

New in version 3.0.0.

setInputCol(value)Sets the value of inputCol.

New in version 3.0.0.

setInputCols(value)Sets the value of inputCols.

New in version 3.0.0.

setOutputCol(value)Sets the value of outputCol.

New in version 3.0.0.

setOutputCols(value)Sets the value of outputCols.

New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.


Attributes Documentation

categorySizes Original number of categories for each feature being encoded. The array contains one value for each input column, in order.

New in version 2.3.0.

dropLast = Param(parent='undefined', name='dropLast', doc='whether to drop the last category')

handleInvalid = Param(parent='undefined', name='handleInvalid', doc="How to handle invalid data during transform(). Options are 'keep' (invalid data presented as an extra categorical feature) or error (throw an error). Note that this Param is only used during transform; during fitting, invalid data will result in an error.")

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

PCA

class pyspark.ml.feature.PCA(*args, **kwargs)
PCA trains a model to project vectors to a lower dimensional space of the top k principal components.

New in version 1.5.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),),
...         (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),),
...         (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)]
>>> df = spark.createDataFrame(data, ["features"])
>>> pca = PCA(k=2, inputCol="features")
>>> pca.setOutputCol("pca_features")
PCA...
>>> model = pca.fit(df)
>>> model.getK()
2
>>> model.setOutputCol("output")
PCAModel...
>>> model.transform(df).collect()[0].output
DenseVector([1.648..., -4.013...])
>>> model.explainedVariance
DenseVector([0.794..., 0.205...])
>>> pcaPath = temp_path + "/pca"
>>> pca.save(pcaPath)
>>> loadedPca = PCA.load(pcaPath)
>>> loadedPca.getK() == pca.getK()
True
>>> modelPath = temp_path + "/pca-model"
>>> model.save(modelPath)
>>> loadedModel = PCAModel.load(modelPath)
>>> loadedModel.pc == model.pc
True
>>> loadedModel.explainedVariance == model.explainedVariance
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True

Methods

clear(param) Clears a param from the param map if it has beenexplicitly set.

copy([extra]) Creates a copy of this instance with the same uid andsome extra params.

explainParam(param) Explains a single param and returns its name, doc,and optional default value and user-supplied value ina string.

explainParams() Returns the documentation of all params with theiroptionally default values and user-supplied values.

extractParamMap([extra]) Extracts the embedded default param values anduser-supplied values, and then merges them with ex-tra values from input into a flat param map, wherethe latter value is used if there exist conflicts, i.e.,with ordering: default param values < user-suppliedvalues < extra.

fit(dataset[, params]) Fits a model to the input dataset with optional param-eters.

fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param mapin paramMaps.

getInputCol() Gets the value of inputCol or its default value.getK() Gets the value of k or its default value.getOrDefault(param) Gets the value of a param in the user-supplied param

map or its default value.getOutputCol() Gets the value of outputCol or its default value.getParam(paramName) Gets a param by its name.hasDefault(param) Checks whether a param has a default value.hasParam(paramName) Tests whether this instance contains a param with a

given (string) name.isDefined(param) Checks whether a param is explicitly set by user or

has a default value.isSet(param) Checks whether a param is explicitly set by user.load(path) Reads an ML instance from the input path, a shortcut

of read().load(path).read() Returns an MLReader instance for this class.save(path) Save this ML instance to the given path, a shortcut of

‘write().save(path)’.set(param, value) Sets a parameter in the embedded param map.setInputCol(value) Sets the value of inputCol.setK(value) Sets the value of k.setOutputCol(value) Sets the value of outputCol.

continues on next page

3.3. ML 449

Page 454: pyspark Documentation

pyspark Documentation, Release master

Table 139 – continued from previous pagesetParams(self, \*[, k, inputCol, outputCol]) Set params for this PCA.write() Returns an MLWriter instance for this ML instance.

Attributes

inputCol
k
outputCol
params: Returns all params ordered by name.

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
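As a hedged illustration of the merge ordering described above (default param values < user-supplied values < extra), the following sketch assumes an active SparkSession, since constructing the Python wrapper touches the JVM; the chosen params are arbitrary.

from pyspark.ml.feature import PCA

pca = PCA(k=2, inputCol="features")        # k and inputCol are user-supplied here

merged = pca.extractParamMap({pca.k: 5})   # values passed via `extra` take precedence

print(merged[pca.k])            # 5: the extra value overrides the user-supplied 2
print(merged[pca.inputCol])     # 'features': user-supplied value, no extra override
print(pca.outputCol in merged)  # True: default values are included in the merged map as well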

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.


params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for eachparam map. Each call to next(modelIterator) will return (index, model) where model wasfit using paramMaps[index]. index values may not be sequential.
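As a rough illustration of the iterator contract described above (index values may not be sequential), the sketch below, which assumes an active SparkSession named spark and toy data, collects the fitted models by index.

from pyspark.ml.feature import PCA
from pyspark.ml.linalg import Vectors

df = spark.createDataFrame(
    [(Vectors.dense([2.0, 0.0, 3.0]),),
     (Vectors.dense([4.0, 1.0, 0.0]),),
     (Vectors.dense([1.0, 5.0, 2.0]),)],
    ["features"],
)

pca = PCA(inputCol="features", outputCol="pca_features")
param_maps = [{pca.k: 1}, {pca.k: 2}, {pca.k: 3}]

# The iterator yields (index, model) pairs, possibly out of order, so place each model by index.
models = [None] * len(param_maps)
for index, model in pca.fitMultiple(df, param_maps):
    models[index] = model

print([m.getK() for m in models])  # [1, 2, 3]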

getInputCol()Gets the value of inputCol or its default value.

getK()Gets the value of k or its default value.

New in version 1.5.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.


set(param, value)Sets a parameter in the embedded param map.

setInputCol(value)Sets the value of inputCol.

setK(value)Sets the value of k.

New in version 1.5.0.

setOutputCol(value)Sets the value of outputCol.

setParams(self, \*, k=None, inputCol=None, outputCol=None)Set params for this PCA.

New in version 1.5.0.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

k = Param(parent='undefined', name='k', doc='the number of principal components')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

PCAModel

class pyspark.ml.feature.PCAModel(java_model=None)
Model fitted by PCA. Transforms vectors to a lower dimensional space.

New in version 1.5.0.

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getK(): Gets the value of k or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.

Attributes

explainedVariance: Returns a vector of proportions of variance explained by each principal component.
inputCol
k
outputCol
params: Returns all params ordered by name.
pc: Returns a principal components Matrix.


Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCol()Gets the value of inputCol or its default value.

getK()Gets the value of k or its default value.

New in version 1.5.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.


isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setInputCol(value)Sets the value of inputCol.

New in version 3.0.0.

setOutputCol(value)Sets the value of outputCol.

New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

explainedVarianceReturns a vector of proportions of variance explained by each principal component.

New in version 2.0.0.

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

k = Param(parent='undefined', name='k', doc='the number of principal components')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

pcReturns a principal components Matrix. Each column is one principal component.

New in version 2.0.0.


PolynomialExpansion

class pyspark.ml.feature.PolynomialExpansion(*args, **kwargs)
Perform feature expansion in a polynomial space. As stated in the Wikipedia article on polynomial expansion, "In mathematics, an expansion of a product of sums expresses it as a sum of products by using the fact that multiplication distributes over addition". Take a 2-variable feature vector as an example: (x, y); if we want to expand it with degree 2, then we get (x, x * x, y, x * y, y * y).

New in version 1.4.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([0.5, 2.0]),)], ["dense"])
>>> px = PolynomialExpansion(degree=2)
>>> px.setInputCol("dense")
PolynomialExpansion...
>>> px.setOutputCol("expanded")
PolynomialExpansion...
>>> px.transform(df).head().expanded
DenseVector([0.5, 0.25, 2.0, 1.0, 4.0])
>>> px.setParams(outputCol="test").transform(df).head().test
DenseVector([0.5, 0.25, 2.0, 1.0, 4.0])
>>> polyExpansionPath = temp_path + "/poly-expansion"
>>> px.save(polyExpansionPath)
>>> loadedPx = PolynomialExpansion.load(polyExpansionPath)
>>> loadedPx.getDegree() == px.getDegree()
True
>>> loadedPx.transform(df).take(1) == px.transform(df).take(1)
True
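To make the growth of the expanded vector concrete, the following illustrative sketch (assuming an active SparkSession named spark) expands the same 2-feature vector with degree 2 and degree 3. The expansion contains every monomial of total degree between 1 and degree, so a 2-feature input yields 5 and 9 terms respectively.

from pyspark.ml.feature import PolynomialExpansion
from pyspark.ml.linalg import Vectors

# Single-row DataFrame with two input features (illustrative).
df = spark.createDataFrame([(Vectors.dense([0.5, 2.0]),)], ["features"])

for degree in (2, 3):
    px = PolynomialExpansion(degree=degree, inputCol="features", outputCol="expanded")
    expanded = px.transform(df).head().expanded
    # Degree 2 on (x, y) yields 5 terms; degree 3 yields 9 terms.
    print(degree, len(expanded))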

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getDegree(): Gets the value of degree or its default value.
getInputCol(): Gets the value of inputCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setDegree(value): Sets the value of degree.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
setParams(self, \*[, degree, inputCol, . . . ]): Sets params for this PolynomialExpansion.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.

Attributes

degree
inputCol
outputCol
params: Returns all params ordered by name.

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.


extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getDegree()Gets the value of degree or its default value.

New in version 1.4.0.

getInputCol()Gets the value of inputCol or its default value.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setDegree(value)Sets the value of degree.

New in version 1.4.0.

setInputCol(value)Sets the value of inputCol.

setOutputCol(value)Sets the value of outputCol.


setParams(self, \*, degree=2, inputCol=None, outputCol=None)Sets params for this PolynomialExpansion.

New in version 1.4.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

degree = Param(parent='undefined', name='degree', doc='the polynomial degree to expand (>= 1)')

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

QuantileDiscretizer

class pyspark.ml.feature.QuantileDiscretizer(*args, **kwargs)
QuantileDiscretizer takes a column with continuous features and outputs a column with binned categorical features. The number of bins can be set using the numBuckets parameter. It is possible that the number of buckets used will be less than this value, for example, if there are too few distinct values of the input to create enough distinct quantiles. Since 3.0.0, QuantileDiscretizer can map multiple columns at once by setting the inputCols parameter. If both of the inputCol and inputCols parameters are set, an Exception will be thrown. To specify the number of buckets for each column, the numBucketsArray parameter can be set, or if the number of buckets should be the same across columns, numBuckets can be set as a convenience.

New in version 2.0.0.


Notes

NaN handling: Note also that QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user can also choose to either keep or remove NaN values within the dataset by setting the handleInvalid parameter. If the user chooses to keep NaN values, they will be handled specially and placed into their own bucket; for example, if 4 buckets are used, then non-NaN data will be put into buckets [0-3], but NaNs will be counted in a special bucket [4].

Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for approxQuantile() for a detailed description). The precision of the approximation can be controlled with the relativeError parameter. The lower and upper bin bounds will be -Infinity and +Infinity, covering all real values.

Examples

>>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
>>> df1 = spark.createDataFrame(values, ["values"])
>>> qds1 = QuantileDiscretizer(inputCol="values", outputCol="buckets")
>>> qds1.setNumBuckets(2)
QuantileDiscretizer...
>>> qds1.setRelativeError(0.01)
QuantileDiscretizer...
>>> qds1.setHandleInvalid("error")
QuantileDiscretizer...
>>> qds1.getRelativeError()
0.01
>>> bucketizer = qds1.fit(df1)
>>> qds1.setHandleInvalid("keep").fit(df1).transform(df1).count()
6
>>> qds1.setHandleInvalid("skip").fit(df1).transform(df1).count()
4
>>> splits = bucketizer.getSplits()
>>> splits[0]
-inf
>>> print("%2.1f" % round(splits[1], 1))
0.4
>>> bucketed = bucketizer.transform(df1).head()
>>> bucketed.buckets
0.0
>>> quantileDiscretizerPath = temp_path + "/quantile-discretizer"
>>> qds1.save(quantileDiscretizerPath)
>>> loadedQds = QuantileDiscretizer.load(quantileDiscretizerPath)
>>> loadedQds.getNumBuckets() == qds1.getNumBuckets()
True
>>> inputs = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, 1.5),
...     (float("nan"), float("nan")), (float("nan"), float("nan"))]
>>> df2 = spark.createDataFrame(inputs, ["input1", "input2"])
>>> qds2 = QuantileDiscretizer(relativeError=0.01, handleInvalid="error", numBuckets=2,
...     inputCols=["input1", "input2"], outputCols=["output1", "output2"])
>>> qds2.getRelativeError()
0.01
>>> qds2.setHandleInvalid("keep").fit(df2).transform(df2).show()
+------+------+-------+-------+
|input1|input2|output1|output2|
+------+------+-------+-------+
|   0.1|   0.0|    0.0|    0.0|
|   0.4|   1.0|    1.0|    1.0|
|   1.2|   1.3|    1.0|    1.0|
|   1.5|   1.5|    1.0|    1.0|
|   NaN|   NaN|    2.0|    2.0|
|   NaN|   NaN|    2.0|    2.0|
+------+------+-------+-------+
...
>>> qds3 = QuantileDiscretizer(relativeError=0.01, handleInvalid="error",
...     numBucketsArray=[5, 10], inputCols=["input1", "input2"],
...     outputCols=["output1", "output2"])
>>> qds3.setHandleInvalid("skip").fit(df2).transform(df2).show()
+------+------+-------+-------+
|input1|input2|output1|output2|
+------+------+-------+-------+
|   0.1|   0.0|    1.0|    1.0|
|   0.4|   1.0|    2.0|    2.0|
|   1.2|   1.3|    3.0|    3.0|
|   1.5|   1.5|    4.0|    4.0|
+------+------+-------+-------+
...
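Fitting a QuantileDiscretizer produces a Bucketizer whose learned split points delimit the buckets. The following sketch is illustrative only; it assumes an active SparkSession named spark, uses made-up column names, prints the splits (whose exact values are approximate because of relativeError), and then counts rows per bucket.

from pyspark.ml.feature import QuantileDiscretizer

# Illustrative data: a single continuous column named "hours".
df = spark.createDataFrame([(float(v),) for v in range(1, 101)], ["hours"])

qds = QuantileDiscretizer(numBuckets=4, inputCol="hours", outputCol="hours_bucket",
                          relativeError=0.01, handleInvalid="error")
bucketizer = qds.fit(df)          # fitting returns a Bucketizer

# The learned split points delimit the buckets; the outer bounds are +/- infinity.
print(bucketizer.getSplits())     # e.g. roughly [-inf, 25.0, 50.0, 75.0, inf]

bucketizer.transform(df).groupBy("hours_bucket").count().orderBy("hours_bucket").show()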

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getHandleInvalid(): Gets the value of handleInvalid or its default value.
getInputCol(): Gets the value of inputCol or its default value.
getInputCols(): Gets the value of inputCols or its default value.
getNumBuckets(): Gets the value of numBuckets or its default value.
getNumBucketsArray(): Gets the value of numBucketsArray or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getOutputCols(): Gets the value of outputCols or its default value.
getParam(paramName): Gets a param by its name.
getRelativeError(): Gets the value of relativeError or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setHandleInvalid(value): Sets the value of handleInvalid.
setInputCol(value): Sets the value of inputCol.
setInputCols(value): Sets the value of inputCols.
setNumBuckets(value): Sets the value of numBuckets.
setNumBucketsArray(value): Sets the value of numBucketsArray.
setOutputCol(value): Sets the value of outputCol.
setOutputCols(value): Sets the value of outputCols.
setParams(self, \*[, numBuckets, inputCol, . . . ]): Set the params for the QuantileDiscretizer.
setRelativeError(value): Sets the value of relativeError.
write(): Returns an MLWriter instance for this ML instance.

Attributes

handleInvalid
inputCol
inputCols
numBuckets
numBucketsArray
outputCol
outputCols
params: Returns all params ordered by name.
relativeError

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters


extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for eachparam map. Each call to next(modelIterator) will return (index, model) where model wasfit using paramMaps[index]. index values may not be sequential.

getHandleInvalid()Gets the value of handleInvalid or its default value.

getInputCol()Gets the value of inputCol or its default value.


getInputCols()Gets the value of inputCols or its default value.

getNumBuckets()Gets the value of numBuckets or its default value.

New in version 2.0.0.

getNumBucketsArray()Gets the value of numBucketsArray or its default value.

New in version 3.0.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getOutputCol()Gets the value of outputCol or its default value.

getOutputCols()Gets the value of outputCols or its default value.

getParam(paramName)Gets a param by its name.

getRelativeError()Gets the value of relativeError or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setHandleInvalid(value)Sets the value of handleInvalid.

setInputCol(value)Sets the value of inputCol.

setInputCols(value)Sets the value of inputCols.

New in version 3.0.0.


setNumBuckets(value)Sets the value of numBuckets.

New in version 2.0.0.

setNumBucketsArray(value)Sets the value of numBucketsArray .

New in version 3.0.0.

setOutputCol(value)Sets the value of outputCol.

setOutputCols(value)Sets the value of outputCols.

New in version 3.0.0.

setParams(self, \*, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None)
Set the params for the QuantileDiscretizer.

New in version 2.0.0.

setRelativeError(value)Sets the value of relativeError.

New in version 2.0.0.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

handleInvalid = Param(parent='undefined', name='handleInvalid', doc="how to handle invalid entries. Options are skip (filter out rows with invalid values), error (throw an error), or keep (keep invalid values in a special additional bucket). Note that in the multiple columns case, the invalid handling is applied to all columns. That said for 'error' it will throw an error if any invalids are found in any columns, for 'skip' it will skip rows with any invalids in any columns, etc.")

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')

numBuckets = Param(parent='undefined', name='numBuckets', doc='Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must be >= 2.')

numBucketsArray = Param(parent='undefined', name='numBucketsArray', doc='Array of number of buckets (quantiles, or categories) into which data points are grouped. This is for multiple columns input. If transforming multiple columns and numBucketsArray is not set, but numBuckets is set, then numBuckets will be applied across all columns.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

relativeError = Param(parent='undefined', name='relativeError', doc='the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]')


RobustScaler

class pyspark.ml.feature.RobustScaler(*args, **kwargs)
RobustScaler removes the median and scales the data according to the quantile range. The quantile range is by default the IQR (Interquartile Range, the quantile range between the 1st quartile = 25th quantile and the 3rd quartile = 75th quantile) but can be configured. Centering and scaling happen independently on each feature by computing the relevant statistics on the samples in the training set. The median and quantile range are then stored to be used on later data using the transform method. Note that NaN values are ignored in the computation of medians and ranges.

New in version 3.0.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> data = [(0, Vectors.dense([0.0, 0.0]),),
...     (1, Vectors.dense([1.0, -1.0]),),
...     (2, Vectors.dense([2.0, -2.0]),),
...     (3, Vectors.dense([3.0, -3.0]),),
...     (4, Vectors.dense([4.0, -4.0]),),]
>>> df = spark.createDataFrame(data, ["id", "features"])
>>> scaler = RobustScaler()
>>> scaler.setInputCol("features")
RobustScaler...
>>> scaler.setOutputCol("scaled")
RobustScaler...
>>> model = scaler.fit(df)
>>> model.setOutputCol("output")
RobustScalerModel...
>>> model.median
DenseVector([2.0, -2.0])
>>> model.range
DenseVector([2.0, 2.0])
>>> model.transform(df).collect()[1].output
DenseVector([0.5, -0.5])
>>> scalerPath = temp_path + "/robust-scaler"
>>> scaler.save(scalerPath)
>>> loadedScaler = RobustScaler.load(scalerPath)
>>> loadedScaler.getWithCentering() == scaler.getWithCentering()
True
>>> loadedScaler.getWithScaling() == scaler.getWithScaling()
True
>>> modelPath = temp_path + "/robust-scaler-model"
>>> model.save(modelPath)
>>> loadedModel = RobustScalerModel.load(modelPath)
>>> loadedModel.median == model.median
True
>>> loadedModel.range == model.range
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
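Beyond the defaults shown above, the quantile range and centering behaviour can be configured. The sketch below is illustrative (an active SparkSession named spark and toy data are assumed); it centers on the median and scales by the 10th-to-90th percentile range, then inspects the statistics stored by the fitted model.

from pyspark.ml.feature import RobustScaler
from pyspark.ml.linalg import Vectors

# Illustrative data with an outlier in the first feature.
df = spark.createDataFrame(
    [(Vectors.dense([1.0, 10.0]),),
     (Vectors.dense([2.0, 20.0]),),
     (Vectors.dense([3.0, 30.0]),),
     (Vectors.dense([100.0, 40.0]),)],
    ["features"],
)

# Center on the median and scale by a wider quantile range (10th to 90th percentile).
scaler = RobustScaler(inputCol="features", outputCol="scaled",
                      withCentering=True, withScaling=True,
                      lower=0.10, upper=0.90)
model = scaler.fit(df)

print(model.median)   # per-feature medians used for centering
print(model.range)    # per-feature quantile ranges used for scaling
model.transform(df).show(truncate=False)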


Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getInputCol(): Gets the value of inputCol or its default value.
getLower(): Gets the value of lower or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
getRelativeError(): Gets the value of relativeError or its default value.
getUpper(): Gets the value of upper or its default value.
getWithCentering(): Gets the value of withCentering or its default value.
getWithScaling(): Gets the value of withScaling or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setLower(value): Sets the value of lower.
setOutputCol(value): Sets the value of outputCol.
setParams(self, \*[, lower, upper, . . . ]): Sets params for this RobustScaler.
setRelativeError(value): Sets the value of relativeError.
setUpper(value): Sets the value of upper.
setWithCentering(value): Sets the value of withCentering.
setWithScaling(value): Sets the value of withScaling.
write(): Returns an MLWriter instance for this ML instance.

Attributes

inputCol
lower
outputCol
params: Returns all params ordered by name.
relativeError
upper
withCentering
withScaling

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters


dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for eachparam map. Each call to next(modelIterator) will return (index, model) where model wasfit using paramMaps[index]. index values may not be sequential.

getInputCol()Gets the value of inputCol or its default value.

getLower()Gets the value of lower or its default value.

New in version 3.0.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

getRelativeError()Gets the value of relativeError or its default value.

getUpper()Gets the value of upper or its default value.

New in version 3.0.0.

getWithCentering()Gets the value of withCentering or its default value.

New in version 3.0.0.

getWithScaling()Gets the value of withScaling or its default value.

New in version 3.0.0.

hasDefault(param)Checks whether a param has a default value.


hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setInputCol(value)Sets the value of inputCol.

New in version 3.0.0.

setLower(value)Sets the value of lower.

New in version 3.0.0.

setOutputCol(value)Sets the value of outputCol.

New in version 3.0.0.

setParams(self, \*, lower=0.25, upper=0.75, withCentering=False, withScaling=True, inputCol=None, outputCol=None, relativeError=0.001)
Sets params for this RobustScaler.

New in version 3.0.0.

setRelativeError(value)Sets the value of relativeError.

New in version 3.0.0.

setUpper(value)Sets the value of upper.

New in version 3.0.0.

setWithCentering(value)Sets the value of withCentering.

New in version 3.0.0.

setWithScaling(value)Sets the value of withScaling.

New in version 3.0.0.

write()Returns an MLWriter instance for this ML instance.


Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

lower = Param(parent='undefined', name='lower', doc='Lower quantile to calculate quantile range')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

relativeError = Param(parent='undefined', name='relativeError', doc='the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]')

upper = Param(parent='undefined', name='upper', doc='Upper quantile to calculate quantile range')

withCentering = Param(parent='undefined', name='withCentering', doc='Whether to center data with median')

withScaling = Param(parent='undefined', name='withScaling', doc='Whether to scale the data to quantile range')

RobustScalerModel

class pyspark.ml.feature.RobustScalerModel(java_model=None)
Model fitted by RobustScaler.

New in version 3.0.0.

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getLower(): Gets the value of lower or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
getRelativeError(): Gets the value of relativeError or its default value.
getUpper(): Gets the value of upper or its default value.
getWithCentering(): Gets the value of withCentering or its default value.
getWithScaling(): Gets the value of withScaling or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.

Attributes

inputCol
lower
median: Median of the RobustScalerModel.
outputCol
params: Returns all params ordered by name.
range: Quantile range of the RobustScalerModel.
relativeError
upper
withCentering
withScaling

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.


explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCol()Gets the value of inputCol or its default value.

getLower()Gets the value of lower or its default value.

New in version 3.0.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

getRelativeError()Gets the value of relativeError or its default value.

getUpper()Gets the value of upper or its default value.

New in version 3.0.0.

getWithCentering()Gets the value of withCentering or its default value.

New in version 3.0.0.

getWithScaling()Gets the value of withScaling or its default value.

New in version 3.0.0.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.


classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setInputCol(value)Sets the value of inputCol.

New in version 3.0.0.

setOutputCol(value)Sets the value of outputCol.

New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

lower = Param(parent='undefined', name='lower', doc='Lower quantile to calculate quantile range')

medianMedian of the RobustScalerModel.

New in version 3.0.0.

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

rangeQuantile range of the RobustScalerModel.

New in version 3.0.0.

relativeError = Param(parent='undefined', name='relativeError', doc='the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]')

upper = Param(parent='undefined', name='upper', doc='Upper quantile to calculate quantile range')


withCentering = Param(parent='undefined', name='withCentering', doc='Whether to center data with median')

withScaling = Param(parent='undefined', name='withScaling', doc='Whether to scale the data to quantile range')

RegexTokenizer

class pyspark.ml.feature.RegexTokenizer(*args, **kwargs)
A regex-based tokenizer that extracts tokens either by using the provided regex pattern (in Java dialect) to split the text (default) or by repeatedly matching the regex (if gaps is false). Optional parameters also allow filtering tokens using a minimal length. It returns an array of strings that can be empty.

New in version 1.4.0.

Examples

>>> df = spark.createDataFrame([("A B c",)], ["text"])>>> reTokenizer = RegexTokenizer()>>> reTokenizer.setInputCol("text")RegexTokenizer...>>> reTokenizer.setOutputCol("words")RegexTokenizer...>>> reTokenizer.transform(df).head()Row(text='A B c', words=['a', 'b', 'c'])>>> # Change a parameter.>>> reTokenizer.setParams(outputCol="tokens").transform(df).head()Row(text='A B c', tokens=['a', 'b', 'c'])>>> # Temporarily modify a parameter.>>> reTokenizer.transform(df, {reTokenizer.outputCol: "words"}).head()Row(text='A B c', words=['a', 'b', 'c'])>>> reTokenizer.transform(df).head()Row(text='A B c', tokens=['a', 'b', 'c'])>>> # Must use keyword arguments to specify params.>>> reTokenizer.setParams("text")Traceback (most recent call last):

...TypeError: Method setParams forces keyword arguments.>>> regexTokenizerPath = temp_path + "/regex-tokenizer">>> reTokenizer.save(regexTokenizerPath)>>> loadedReTokenizer = RegexTokenizer.load(regexTokenizerPath)>>> loadedReTokenizer.getMinTokenLength() == reTokenizer.getMinTokenLength()True>>> loadedReTokenizer.getGaps() == reTokenizer.getGaps()True>>> loadedReTokenizer.transform(df).take(1) == reTokenizer.transform(df).take(1)True
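The gaps parameter switches between splitting on the pattern and matching the pattern. The following sketch is illustrative, assuming an active SparkSession named spark; the pattern strings are examples, not defaults.

from pyspark.ml.feature import RegexTokenizer

df = spark.createDataFrame([("the  quick-brown fox, 2 jumps",)], ["text"])

# gaps=True: the pattern describes the separators between tokens.
splitter = RegexTokenizer(inputCol="text", outputCol="tokens",
                          pattern=r"[\s,\-]+", gaps=True)

# gaps=False: the pattern describes the tokens themselves (here: runs of letters).
matcher = RegexTokenizer(inputCol="text", outputCol="tokens",
                         pattern=r"[a-z]+", gaps=False)

print(splitter.transform(df).head().tokens)  # includes the digit token '2'
print(matcher.transform(df).head().tokens)   # only letter runs, so '2' is dropped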


Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getGaps(): Gets the value of gaps or its default value.
getInputCol(): Gets the value of inputCol or its default value.
getMinTokenLength(): Gets the value of minTokenLength or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
getPattern(): Gets the value of pattern or its default value.
getToLowercase(): Gets the value of toLowercase or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setGaps(value): Sets the value of gaps.
setInputCol(value): Sets the value of inputCol.
setMinTokenLength(value): Sets the value of minTokenLength.
setOutputCol(value): Sets the value of outputCol.
setParams(self, \*[, minTokenLength, gaps, . . . ]): Sets params for this RegexTokenizer.
setPattern(value): Sets the value of pattern.
setToLowercase(value): Sets the value of toLowercase.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.


Attributes

gaps
inputCol
minTokenLength
outputCol
params: Returns all params ordered by name.
pattern
toLowercase

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getGaps()Gets the value of gaps or its default value.

New in version 1.4.0.

getInputCol()Gets the value of inputCol or its default value.

getMinTokenLength()Gets the value of minTokenLength or its default value.

New in version 1.4.0.


getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getOutputCol()Gets the value of outputCol or its default value.

getParam(paramName)Gets a param by its name.

getPattern()Gets the value of pattern or its default value.

New in version 1.4.0.

getToLowercase()Gets the value of toLowercase or its default value.

New in version 2.0.0.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setGaps(value)Sets the value of gaps.

New in version 1.4.0.

setInputCol(value)Sets the value of inputCol.

setMinTokenLength(value)Sets the value of minTokenLength.

New in version 1.4.0.

setOutputCol(value)Sets the value of outputCol.

setParams(self, \*, minTokenLength=1, gaps=True, pattern="\s+", inputCol=None, outputCol=None, toLowercase=True)
Sets params for this RegexTokenizer.

New in version 1.4.0.


setPattern(value)Sets the value of pattern.

New in version 1.4.0.

setToLowercase(value)Sets the value of toLowercase.

New in version 2.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

gaps = Param(parent='undefined', name='gaps', doc='whether regex splits on gaps (True) or matches tokens (False)')

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

minTokenLength = Param(parent='undefined', name='minTokenLength', doc='minimum token length (>= 0)')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

pattern = Param(parent='undefined', name='pattern', doc='regex pattern (Java dialect) used for tokenizing')

toLowercase = Param(parent='undefined', name='toLowercase', doc='whether to convert all characters to lowercase before tokenizing')

RFormula

class pyspark.ml.feature.RFormula(*args, **kwargs)
Implements the transforms required for fitting a dataset against an R model formula. Currently we support a limited subset of the R operators, including '~', '.', ':', '+', '-', '*', and '^'.

New in version 1.5.0.


Notes

Also see the R formula docs.

Examples

>>> df = spark.createDataFrame([
...     (1.0, 1.0, "a"),
...     (0.0, 2.0, "b"),
...     (0.0, 0.0, "a")
... ], ["y", "x", "s"])
>>> rf = RFormula(formula="y ~ x + s")
>>> model = rf.fit(df)
>>> model.getLabelCol()
'label'
>>> model.transform(df).show()
+---+---+---+---------+-----+
|  y|  x|  s| features|label|
+---+---+---+---------+-----+
|1.0|1.0|  a|[1.0,1.0]|  1.0|
|0.0|2.0|  b|[2.0,0.0]|  0.0|
|0.0|0.0|  a|[0.0,1.0]|  0.0|
+---+---+---+---------+-----+
...
>>> rf.fit(df, {rf.formula: "y ~ . - s"}).transform(df).show()
+---+---+---+--------+-----+
|  y|  x|  s|features|label|
+---+---+---+--------+-----+
|1.0|1.0|  a|   [1.0]|  1.0|
|0.0|2.0|  b|   [2.0]|  0.0|
|0.0|0.0|  a|   [0.0]|  0.0|
+---+---+---+--------+-----+
...
>>> rFormulaPath = temp_path + "/rFormula"
>>> rf.save(rFormulaPath)
>>> loadedRF = RFormula.load(rFormulaPath)
>>> loadedRF.getFormula() == rf.getFormula()
True
>>> loadedRF.getFeaturesCol() == rf.getFeaturesCol()
True
>>> loadedRF.getLabelCol() == rf.getLabelCol()
True
>>> loadedRF.getHandleInvalid() == rf.getHandleInvalid()
True
>>> str(loadedRF)
'RFormula(y ~ x + s) (uid=...)'
>>> modelPath = temp_path + "/rFormulaModel"
>>> model.save(modelPath)
>>> loadedModel = RFormulaModel.load(modelPath)
>>> loadedModel.uid == model.uid
True
>>> loadedModel.transform(df).show()
+---+---+---+---------+-----+
|  y|  x|  s| features|label|
+---+---+---+---------+-----+
|1.0|1.0|  a|[1.0,1.0]|  1.0|
|0.0|2.0|  b|[2.0,0.0]|  0.0|
|0.0|0.0|  a|[0.0,1.0]|  0.0|
+---+---+---+---------+-----+
...
>>> str(loadedModel)
'RFormulaModel(ResolvedRFormula(label=y, terms=[x,s], hasIntercept=true)) (uid=...)'
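The remaining operators behave analogously; as a hedged sketch (reusing df from the example above), ':' adds an interaction term between x and the encoded s:

rf_inter = RFormula(formula="y ~ x + x:s")
rf_inter.fit(df).transform(df).select("features", "label").show()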

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getForceIndexLabel(): Gets the value of forceIndexLabel.
getFormula(): Gets the value of formula.
getHandleInvalid(): Gets the value of handleInvalid or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getStringIndexerOrderType(): Gets the value of stringIndexerOrderType or its default value 'frequencyDesc'.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setFeaturesCol(value): Sets the value of featuresCol.
setForceIndexLabel(value): Sets the value of forceIndexLabel.
setFormula(value): Sets the value of formula.
setHandleInvalid(value): Sets the value of handleInvalid.
setLabelCol(value): Sets the value of labelCol.
setParams(self, \*[, formula, featuresCol, ...]): Sets params for RFormula.
setStringIndexerOrderType(value): Sets the value of stringIndexerOrderType.
write(): Returns an MLWriter instance for this ML instance.

Attributes

featuresCol
forceIndexLabel
formula
handleInvalid
labelCol
params: Returns all params ordered by name.
stringIndexerOrderType

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns


dict merged param map
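As a hedged illustration of that ordering (default param values < user-supplied values < extra), reusing the rf instance from the example above:

rf.setHandleInvalid("skip")                              # user-supplied value
merged = rf.extractParamMap({rf.handleInvalid: "keep"})  # extra value takes precedence
merged[rf.handleInvalid]                                 # expected 'keep'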

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
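A hedged sketch of consuming this iterator (reusing df and rf from the example above); the returned index is used for placement because models may arrive out of order:

paramMaps = [{rf.formula: "y ~ x"}, {rf.formula: "y ~ x + s"}]
models = [None] * len(paramMaps)
for index, model in rf.fitMultiple(df, paramMaps):
    models[index] = model   # store each fitted model under its param-map index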

getFeaturesCol()
Gets the value of featuresCol or its default value.

getForceIndexLabel()
Gets the value of forceIndexLabel.

New in version 2.1.0.

getFormula()
Gets the value of formula.

New in version 1.5.0.

getHandleInvalid()
Gets the value of handleInvalid or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getStringIndexerOrderType()
Gets the value of stringIndexerOrderType or its default value 'frequencyDesc'.

New in version 2.3.0.


hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

setForceIndexLabel(value)
Sets the value of forceIndexLabel.

New in version 2.1.0.

setFormula(value)
Sets the value of formula.

New in version 1.5.0.

setHandleInvalid(value)
Sets the value of handleInvalid.

setLabelCol(value)
Sets the value of labelCol.

setParams(self, \*, formula=None, featuresCol="features", labelCol="label", forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", handleInvalid="error")
Sets params for RFormula.

New in version 1.5.0.

setStringIndexerOrderType(value)
Sets the value of stringIndexerOrderType.

New in version 2.3.0.

write()
Returns an MLWriter instance for this ML instance.


Attributes Documentation

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

forceIndexLabel = Param(parent='undefined', name='forceIndexLabel', doc='Force to index label whether it is numeric or string')

formula = Param(parent='undefined', name='formula', doc='R model formula')

handleInvalid = Param(parent='undefined', name='handleInvalid', doc="how to handle invalid entries. Options are 'skip' (filter out rows with invalid values), 'error' (throw an error), or 'keep' (put invalid data in a special additional bucket, at index numLabels).")

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

stringIndexerOrderType = Param(parent='undefined', name='stringIndexerOrderType', doc='How to order categories of a string feature column used by StringIndexer. The last category after ordering is dropped when encoding strings. Supported options: frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. The default value is frequencyDesc. When the ordering is set to alphabetDesc, RFormula drops the same category as R when encoding strings.')
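A hedged sketch of overriding this ordering (reusing df from the example above); with alphabetDesc, RFormula should drop the same reference level for the string column s that R would:

rf_alpha = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc")
rf_alpha.fit(df).transform(df).select("features", "label").show()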

RFormulaModel

class pyspark.ml.feature.RFormulaModel(java_model=None)
Model fitted by RFormula. Fitting is required to determine the factor levels of formula terms.

New in version 1.5.0.

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getForceIndexLabel(): Gets the value of forceIndexLabel.
getFormula(): Gets the value of formula.
getHandleInvalid(): Gets the value of handleInvalid or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getStringIndexerOrderType(): Gets the value of stringIndexerOrderType or its default value 'frequencyDesc'.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.

Attributes

featuresCol
forceIndexLabel
formula
handleInvalid
labelCol
params: Returns all params ordered by name.
stringIndexerOrderType

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.


Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getFeaturesCol()
Gets the value of featuresCol or its default value.

getForceIndexLabel()
Gets the value of forceIndexLabel.

New in version 2.1.0.

getFormula()
Gets the value of formula.

New in version 1.5.0.

getHandleInvalid()
Gets the value of handleInvalid or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getStringIndexerOrderType()
Gets the value of stringIndexerOrderType or its default value 'frequencyDesc'.

New in version 2.3.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.


New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

forceIndexLabel = Param(parent='undefined', name='forceIndexLabel', doc='Force to index label whether it is numeric or string')

formula = Param(parent='undefined', name='formula', doc='R model formula')

handleInvalid = Param(parent='undefined', name='handleInvalid', doc="how to handle invalid entries. Options are 'skip' (filter out rows with invalid values), 'error' (throw an error), or 'keep' (put invalid data in a special additional bucket, at index numLabels).")

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

stringIndexerOrderType = Param(parent='undefined', name='stringIndexerOrderType', doc='How to order categories of a string feature column used by StringIndexer. The last category after ordering is dropped when encoding strings. Supported options: frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. The default value is frequencyDesc. When the ordering is set to alphabetDesc, RFormula drops the same category as R when encoding strings.')

SQLTransformer

class pyspark.ml.feature.SQLTransformer(*args, **kwargs)
Implements the transforms which are defined by SQL statement. Currently we only support SQL syntax like SELECT ... FROM __THIS__ where __THIS__ represents the underlying table of the input dataset.

New in version 1.6.0.

Examples

>>> df = spark.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"])
>>> sqlTrans = SQLTransformer(
...     statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
>>> sqlTrans.transform(df).head()
Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0)
>>> sqlTransformerPath = temp_path + "/sql-transformer"
>>> sqlTrans.save(sqlTransformerPath)
>>> loadedSqlTrans = SQLTransformer.load(sqlTransformerPath)
>>> loadedSqlTrans.getStatement() == sqlTrans.getStatement()
True
>>> loadedSqlTrans.transform(df).take(1) == sqlTrans.transform(df).take(1)
True
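Because __THIS__ simply stands for the input DataFrame, filtering statements work as well; a hedged sketch reusing df from the example above:

filterTrans = SQLTransformer(statement="SELECT * FROM __THIS__ WHERE v1 > 1.0")
filterTrans.transform(df).count()   # only the (2, 2.0, 5.0) row should remain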


Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getStatement(): Gets the value of statement or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setParams(self, \*[, statement]): Sets params for this SQLTransformer.
setStatement(value): Sets the value of statement.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.

Attributes

params: Returns all params ordered by name.
statement


Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getStatement()
Gets the value of statement or its default value.

New in version 1.6.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).


classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setParams(self, \*, statement=None)
Sets params for this SQLTransformer.

New in version 1.6.0.

setStatement(value)
Sets the value of statement.

New in version 1.6.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

statement = Param(parent='undefined', name='statement', doc='SQL statement')

StandardScaler

class pyspark.ml.feature.StandardScaler(*args, **kwargs)
Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set.

The "unit std" is computed using the corrected sample standard deviation, which is computed as the square root of the unbiased sample variance.

New in version 1.4.0.


Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"])
>>> standardScaler = StandardScaler()
>>> standardScaler.setInputCol("a")
StandardScaler...
>>> standardScaler.setOutputCol("scaled")
StandardScaler...
>>> model = standardScaler.fit(df)
>>> model.getInputCol()
'a'
>>> model.setOutputCol("output")
StandardScalerModel...
>>> model.mean
DenseVector([1.0])
>>> model.std
DenseVector([1.4142])
>>> model.transform(df).collect()[1].output
DenseVector([1.4142])
>>> standardScalerPath = temp_path + "/standard-scaler"
>>> standardScaler.save(standardScalerPath)
>>> loadedStandardScaler = StandardScaler.load(standardScalerPath)
>>> loadedStandardScaler.getWithMean() == standardScaler.getWithMean()
True
>>> loadedStandardScaler.getWithStd() == standardScaler.getWithStd()
True
>>> modelPath = temp_path + "/standard-scaler-model"
>>> model.save(modelPath)
>>> loadedModel = StandardScalerModel.load(modelPath)
>>> loadedModel.std == model.std
True
>>> loadedModel.mean == model.mean
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
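As a hedged cross-check of the corrected sample standard deviation mentioned above, NumPy's ddof=1 estimate over the two training points matches the fitted model.std:

import numpy as np

np.std([0.0, 2.0], ddof=1)   # 1.4142..., the value reported as model.std[0]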

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getInputCol(): Gets the value of inputCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
getWithMean(): Gets the value of withMean or its default value.
getWithStd(): Gets the value of withStd or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
setParams(self, \*[, withMean, withStd, ...]): Sets params for this StandardScaler.
setWithMean(value): Sets the value of withMean.
setWithStd(value): Sets the value of withStd.
write(): Returns an MLWriter instance for this ML instance.

Attributes

inputCol
outputCol
params: Returns all params ordered by name.
withMean
withStd


Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.


Returns

_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getInputCol()
Gets the value of inputCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getWithMean()
Gets the value of withMean or its default value.

New in version 1.4.0.

getWithStd()
Gets the value of withStd or its default value.

New in version 1.4.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, \*, withMean=False, withStd=True, inputCol=None, outputCol=None)
Sets params for this StandardScaler.

New in version 1.4.0.


setWithMean(value)
Sets the value of withMean.

New in version 1.4.0.

setWithStd(value)
Sets the value of withStd.

New in version 1.4.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

withMean = Param(parent='undefined', name='withMean', doc='Center data with mean')

withStd = Param(parent='undefined', name='withStd', doc='Scale to unit standard deviation')

StandardScalerModel

class pyspark.ml.feature.StandardScalerModel(java_model=None)
Model fitted by StandardScaler.

New in version 1.4.0.

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol(): Gets the value of inputCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
getWithMean(): Gets the value of withMean or its default value.
getWithStd(): Gets the value of withStd or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.

Attributes

inputCol
mean: Mean of the StandardScalerModel.
outputCol
params: Returns all params ordered by name.
std: Standard deviation of the StandardScalerModel.
withMean
withStd

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCol()
Gets the value of inputCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getWithMean()
Gets the value of withMean or its default value.

New in version 1.4.0.

getWithStd()
Gets the value of withStd or its default value.

New in version 1.4.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.


setInputCol(value)
Sets the value of inputCol.

setOutputCol(value)
Sets the value of outputCol.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

mean
Mean of the StandardScalerModel.

New in version 2.0.0.

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

std
Standard deviation of the StandardScalerModel.

New in version 2.0.0.

withMean = Param(parent='undefined', name='withMean', doc='Center data with mean')

withStd = Param(parent='undefined', name='withStd', doc='Scale to unit standard deviation')
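A hedged sketch that reproduces the transformation by hand from the fitted statistics (reusing model and df from the StandardScaler example above; withMean defaults to False, so only the division by std applies):

scaled_by_hand = [row.a[0] / model.std[0] for row in df.collect()]
# should match the values in the column produced by model.transform(df)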

StopWordsRemover

class pyspark.ml.feature.StopWordsRemover(*args, **kwargs)
A feature transformer that filters out stop words from input. Since 3.0.0, StopWordsRemover can filter out multiple columns at once by setting the inputCols parameter. Note that when both the inputCol and inputCols parameters are set, an Exception will be thrown.

New in version 1.6.0.


Notes

null values from input array are preserved unless adding null to stopWords explicitly.

Examples

>>> df = spark.createDataFrame([(["a", "b", "c"],)], ["text"])>>> remover = StopWordsRemover(stopWords=["b"])>>> remover.setInputCol("text")StopWordsRemover...>>> remover.setOutputCol("words")StopWordsRemover...>>> remover.transform(df).head().words == ['a', 'c']True>>> stopWordsRemoverPath = temp_path + "/stopwords-remover">>> remover.save(stopWordsRemoverPath)>>> loadedRemover = StopWordsRemover.load(stopWordsRemoverPath)>>> loadedRemover.getStopWords() == remover.getStopWords()True>>> loadedRemover.getCaseSensitive() == remover.getCaseSensitive()True>>> loadedRemover.transform(df).take(1) == remover.transform(df).take(1)True>>> df2 = spark.createDataFrame([(["a", "b", "c"], ["a", "b"])], ["text1", "text2→˓"])>>> remover2 = StopWordsRemover(stopWords=["b"])>>> remover2.setInputCols(["text1", "text2"]).setOutputCols(["words1", "words2"])StopWordsRemover...>>> remover2.transform(df2).show()+---------+------+------+------+| text1| text2|words1|words2|+---------+------+------+------+|[a, b, c]|[a, b]|[a, c]| [a]|+---------+------+------+------+...

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCaseSensitive(): Gets the value of caseSensitive or its default value.
getInputCol(): Gets the value of inputCol or its default value.
getInputCols(): Gets the value of inputCols or its default value.
getLocale(): Gets the value of locale.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getOutputCols(): Gets the value of outputCols or its default value.
getParam(paramName): Gets a param by its name.
getStopWords(): Gets the value of stopWords or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
loadDefaultStopWords(language): Loads the default stop words for the given language.
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setCaseSensitive(value): Sets the value of caseSensitive.
setInputCol(value): Sets the value of inputCol.
setInputCols(value): Sets the value of inputCols.
setLocale(value): Sets the value of locale.
setOutputCol(value): Sets the value of outputCol.
setOutputCols(value): Sets the value of outputCols.
setParams(self, \*[, inputCol, outputCol, ...]): Sets params for this StopWordsRemover.
setStopWords(value): Sets the value of stopWords.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.


Attributes

caseSensitive
inputCol
inputCols
locale
outputCol
outputCols
params: Returns all params ordered by name.
stopWords

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getCaseSensitive()
Gets the value of caseSensitive or its default value.

New in version 1.6.0.

getInputCol()
Gets the value of inputCol or its default value.

getInputCols()
Gets the value of inputCols or its default value.


getLocale()
Gets the value of locale.

New in version 2.4.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getOutputCols()
Gets the value of outputCols or its default value.

getParam(paramName)
Gets a param by its name.

getStopWords()
Gets the value of stopWords or its default value.

New in version 1.6.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

static loadDefaultStopWords(language)
Loads the default stop words for the given language. Supported languages: danish, dutch, english, finnish, french, german, hungarian, italian, norwegian, portuguese, russian, spanish, swedish, turkish

New in version 2.0.0.
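A hedged sketch of seeding a remover with one of these bundled lists (it reuses df from the example above; the exact word list shipped with a given Spark build may vary):

english = StopWordsRemover.loadDefaultStopWords("english")
remover_en = StopWordsRemover(inputCol="text", outputCol="cleaned", stopWords=english)
remover_en.transform(df).head().cleaned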

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setCaseSensitive(value)
Sets the value of caseSensitive.

New in version 1.6.0.

setInputCol(value)
Sets the value of inputCol.

setInputCols(value)
Sets the value of inputCols.

New in version 3.0.0.


setLocale(value)
Sets the value of locale.

New in version 2.4.0.

setOutputCol(value)
Sets the value of outputCol.

setOutputCols(value)
Sets the value of outputCols.

New in version 3.0.0.

setParams(self, \*, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, locale=None, inputCols=None, outputCols=None)
Sets params for this StopWordsRemover.

New in version 1.6.0.

setStopWords(value)
Sets the value of stopWords.

New in version 1.6.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

caseSensitive = Param(parent='undefined', name='caseSensitive', doc='whether to do a case sensitive comparison over the stop words')

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')

locale = Param(parent='undefined', name='locale', doc='locale of the input. ignored when case sensitive is true')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

stopWords = Param(parent='undefined', name='stopWords', doc='The words to be filtered out')


StringIndexer

class pyspark.ml.feature.StringIndexer(*args, **kwargs)
A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. The indices are in [0, numLabels). By default, this is ordered by label frequencies so the most frequent label gets index 0. The ordering behavior is controlled by setting stringOrderType. Its default value is 'frequencyDesc'.

New in version 1.4.0.

Examples

>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed",
...     stringOrderType="frequencyDesc")
>>> stringIndexer.setHandleInvalid("error")
StringIndexer...
>>> model = stringIndexer.fit(stringIndDf)
>>> model.setHandleInvalid("error")
StringIndexerModel...
>>> td = model.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
...     key=lambda x: x[0])
[(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
>>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels)
>>> itd = inverter.transform(td)
>>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
...     key=lambda x: x[0])
[(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
>>> stringIndexerPath = temp_path + "/string-indexer"
>>> stringIndexer.save(stringIndexerPath)
>>> loadedIndexer = StringIndexer.load(stringIndexerPath)
>>> loadedIndexer.getHandleInvalid() == stringIndexer.getHandleInvalid()
True
>>> modelPath = temp_path + "/string-indexer-model"
>>> model.save(modelPath)
>>> loadedModel = StringIndexerModel.load(modelPath)
>>> loadedModel.labels == model.labels
True
>>> indexToStringPath = temp_path + "/index-to-string"
>>> inverter.save(indexToStringPath)
>>> loadedInverter = IndexToString.load(indexToStringPath)
>>> loadedInverter.getLabels() == inverter.getLabels()
True
>>> loadedModel.transform(stringIndDf).take(1) == model.transform(stringIndDf).take(1)
True
>>> stringIndexer.getStringOrderType()
'frequencyDesc'
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error",
...     stringOrderType="alphabetDesc")
>>> model = stringIndexer.fit(stringIndDf)
>>> td = model.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
...     key=lambda x: x[0])
[(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)]
>>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"],
...     inputCol="label", outputCol="indexed", handleInvalid="error")
>>> result = fromlabelsModel.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]),
...     key=lambda x: x[0])
[(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)]
>>> testData = sc.parallelize([Row(id=0, label1="a", label2="e"),
...     Row(id=1, label1="b", label2="f"),
...     Row(id=2, label1="c", label2="e"),
...     Row(id=3, label1="a", label2="f"),
...     Row(id=4, label1="a", label2="f"),
...     Row(id=5, label1="c", label2="f")], 3)
>>> multiRowDf = spark.createDataFrame(testData)
>>> inputs = ["label1", "label2"]
>>> outputs = ["index1", "index2"]
>>> stringIndexer = StringIndexer(inputCols=inputs, outputCols=outputs)
>>> model = stringIndexer.fit(multiRowDf)
>>> result = model.transform(multiRowDf)
>>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
...     result.index2).collect()]), key=lambda x: x[0])
[(0, 0.0, 1.0), (1, 2.0, 0.0), (2, 1.0, 1.0), (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)]
>>> fromlabelsModel = StringIndexerModel.from_arrays_of_labels([["a", "b", "c"], ["e", "f"]],
...     inputCols=inputs, outputCols=outputs)
>>> result = fromlabelsModel.transform(multiRowDf)
>>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
...     result.index2).collect()]), key=lambda x: x[0])
[(0, 0.0, 0.0), (1, 1.0, 1.0), (2, 2.0, 0.0), (3, 0.0, 1.0), (4, 0.0, 1.0), (5, 2.0, 1.0)]

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getHandleInvalid(): Gets the value of handleInvalid or its default value.
getInputCol(): Gets the value of inputCol or its default value.
getInputCols(): Gets the value of inputCols or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getOutputCols(): Gets the value of outputCols or its default value.
getParam(paramName): Gets a param by its name.
getStringOrderType(): Gets the value of stringOrderType or its default value 'frequencyDesc'.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setHandleInvalid(value): Sets the value of handleInvalid.
setInputCol(value): Sets the value of inputCol.
setInputCols(value): Sets the value of inputCols.
setOutputCol(value): Sets the value of outputCol.
setOutputCols(value): Sets the value of outputCols.
setParams(self, \*[, inputCol, outputCol, ...]): Sets params for this StringIndexer.
setStringOrderType(value): Sets the value of stringOrderType.
write(): Returns an MLWriter instance for this ML instance.

Attributes

handleInvalid
inputCol
inputCols
outputCol
outputCols
params: Returns all params ordered by name.
stringOrderType

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters


dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getHandleInvalid()
Gets the value of handleInvalid or its default value.

getInputCol()
Gets the value of inputCol or its default value.

getInputCols()
Gets the value of inputCols or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getOutputCols()
Gets the value of outputCols or its default value.

getParam(paramName)
Gets a param by its name.

getStringOrderType()
Gets the value of stringOrderType or its default value 'frequencyDesc'.

New in version 2.3.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setHandleInvalid(value)
Sets the value of handleInvalid.


setInputCol(value)
Sets the value of inputCol.

setInputCols(value)
Sets the value of inputCols.

New in version 3.0.0.

setOutputCol(value)
Sets the value of outputCol.

setOutputCols(value)
Sets the value of outputCols.

New in version 3.0.0.

setParams(self, \*, inputCol=None, outputCol=None, inputCols=None, outputCols=None, handleInvalid="error", stringOrderType="frequencyDesc")
Sets params for this StringIndexer.

New in version 1.4.0.

setStringOrderType(value)
Sets the value of stringOrderType.

New in version 2.3.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

handleInvalid = Param(parent='undefined', name='handleInvalid', doc="how to handle invalid data (unseen or NULL values) in features and label column of string type. Options are 'skip' (filter out rows with invalid data), error (throw an error), or 'keep' (put invalid data in a special additional bucket, at index numLabels).")

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

stringOrderType = Param(parent='undefined', name='stringOrderType', doc='How to order labels of string column. The first label after ordering is assigned an index of 0. Supported options: frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. Default is frequencyDesc. In case of equal frequency when under frequencyDesc/Asc, the strings are further sorted alphabetically')
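A hedged sketch of these two params together (illustrative column names; it assumes the active SparkSession spark used throughout these examples): handleInvalid="keep" routes labels unseen during fitting to the extra bucket at index numLabels, and stringOrderType decides which label gets index 0.

train = spark.createDataFrame([("a",), ("b",), ("a",)], ["label"])
test = spark.createDataFrame([("a",), ("z",)], ["label"])
indexer = StringIndexer(inputCol="label", outputCol="indexed",
                        handleInvalid="keep", stringOrderType="frequencyDesc")
indexer.fit(train).transform(test).show()
# expected: 'a' -> 0.0 (most frequent), unseen 'z' -> 2.0 (the extra bucket)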

StringIndexerModel

class pyspark.ml.feature.StringIndexerModel(java_model=None)
Model fitted by StringIndexer.

New in version 1.4.0.
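For example, the model can also be built directly from a known label list instead of being fitted; a minimal sketch (assuming an active SparkSession named spark):

from pyspark.ml.feature import StringIndexerModel

model = StringIndexerModel.from_labels(
    ["a", "b", "c"], inputCol="category", outputCol="categoryIndex",
    handleInvalid="error")
df = spark.createDataFrame([("a",), ("c",)], ["category"])
# indices follow the position in the supplied label list: "a" -> 0.0, "c" -> 2.0
model.transform(df).show()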


Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
from_arrays_of_labels(arrayOfLabels, inputCols) Construct the model directly from an array of array of label strings, requires an active SparkContext.
from_labels(labels, inputCol[, outputCol, ...]) Construct the model directly from an array of label strings, requires an active SparkContext.
getHandleInvalid() Gets the value of handleInvalid or its default value.
getInputCol() Gets the value of inputCol or its default value.
getInputCols() Gets the value of inputCols or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getOutputCols() Gets the value of outputCols or its default value.
getParam(paramName) Gets a param by its name.
getStringOrderType() Gets the value of stringOrderType or its default value 'frequencyDesc'.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setHandleInvalid(value) Sets the value of handleInvalid.
setInputCol(value) Sets the value of inputCol.
setInputCols(value) Sets the value of inputCols.
setOutputCol(value) Sets the value of outputCol.
setOutputCols(value) Sets the value of outputCols.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.


Attributes

handleInvalid
inputCol
inputCols
labels Ordered list of labels, corresponding to indices to be assigned.
labelsArray Array of ordered list of labels, corresponding to indices to be assigned for each input column.
outputCol
outputCols
params Returns all params ordered by name.
stringOrderType

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
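The ordering above matters in practice; the sketch below (assuming an active SparkSession, with an indexer created purely for illustration) shows an explicitly set value overriding the default, and an extra map overriding both:

from pyspark.ml.feature import StringIndexer

indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
indexer.getOrDefault(indexer.handleInvalid)               # 'error', the default
indexer.setHandleInvalid("skip")                          # user-supplied value beats the default
merged = indexer.extractParamMap({indexer.handleInvalid: "keep"})
merged[indexer.handleInvalid]                             # 'keep', the extra value beats both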

classmethod from_arrays_of_labels(arrayOfLabels, inputCols, outputCols=None, handleInvalid=None)
Construct the model directly from an array of array of label strings, requires an active SparkContext.

New in version 3.0.0.


classmethod from_labels(labels, inputCol, outputCol=None, handleInvalid=None)
Construct the model directly from an array of label strings, requires an active SparkContext.

New in version 2.4.0.

getHandleInvalid()
Gets the value of handleInvalid or its default value.

getInputCol()
Gets the value of inputCol or its default value.

getInputCols()
Gets the value of inputCols or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getOutputCols()
Gets the value of outputCols or its default value.

getParam(paramName)
Gets a param by its name.

getStringOrderType()
Gets the value of stringOrderType or its default value 'frequencyDesc'.

New in version 2.3.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setHandleInvalid(value)
Sets the value of handleInvalid.

New in version 2.4.0.

setInputCol(value)
Sets the value of inputCol.


setInputCols(value)
Sets the value of inputCols.

New in version 3.0.0.

setOutputCol(value)
Sets the value of outputCol.

setOutputCols(value)
Sets the value of outputCols.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

handleInvalid = Param(parent='undefined', name='handleInvalid', doc="how to handle invalid data (unseen or NULL values) in features and label column of string type. Options are 'skip' (filter out rows with invalid data), error (throw an error), or 'keep' (put invalid data in a special additional bucket, at index numLabels).")

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')

labels
Ordered list of labels, corresponding to indices to be assigned.

Deprecated since version 3.1.0: It will be removed in future versions. Use labelsArray method instead.

New in version 1.5.0.

labelsArray
Array of ordered list of labels, corresponding to indices to be assigned for each input column.

New in version 3.0.2.
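A small sketch of migrating off the deprecated attribute (model is assumed to be a fitted single-input-column StringIndexerModel; the exact container type of the labelsArray elements can vary between versions):

old_labels = model.labels                 # deprecated since 3.1.0
new_labels = list(model.labelsArray[0])   # labelsArray holds one label list per input column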

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

outputCols = Param(parent='undefined', name='outputCols', doc='output column names.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

stringOrderType = Param(parent='undefined', name='stringOrderType', doc='How to order labels of string column. The first label after ordering is assigned an index of 0. Supported options: frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. Default is frequencyDesc. In case of equal frequency when under frequencyDesc/Asc, the strings are further sorted alphabetically')


Tokenizer

class pyspark.ml.feature.Tokenizer(*args, **kwargs)
A tokenizer that converts the input string to lowercase and then splits it by white spaces.

New in version 1.3.0.

Examples

>>> df = spark.createDataFrame([("a b c",)], ["text"])
>>> tokenizer = Tokenizer(outputCol="words")
>>> tokenizer.setInputCol("text")
Tokenizer...
>>> tokenizer.transform(df).head()
Row(text='a b c', words=['a', 'b', 'c'])
>>> # Change a parameter.
>>> tokenizer.setParams(outputCol="tokens").transform(df).head()
Row(text='a b c', tokens=['a', 'b', 'c'])
>>> # Temporarily modify a parameter.
>>> tokenizer.transform(df, {tokenizer.outputCol: "words"}).head()
Row(text='a b c', words=['a', 'b', 'c'])
>>> tokenizer.transform(df).head()
Row(text='a b c', tokens=['a', 'b', 'c'])
>>> # Must use keyword arguments to specify params.
>>> tokenizer.setParams("text")
Traceback (most recent call last):
    ...
TypeError: Method setParams forces keyword arguments.
>>> tokenizerPath = temp_path + "/tokenizer"
>>> tokenizer.save(tokenizerPath)
>>> loadedTokenizer = Tokenizer.load(tokenizerPath)
>>> loadedTokenizer.transform(df).head().tokens == tokenizer.transform(df).head().tokens
True

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getInputCol() Gets the value of inputCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setOutputCol(value) Sets the value of outputCol.
setParams(self, \*[, inputCol, outputCol]) Sets params for this Tokenizer.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

inputCol
outputCol
params Returns all params ordered by name.

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getInputCol()
Gets the value of inputCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, \*, inputCol=None, outputCol=None)
Sets params for this Tokenizer.

New in version 1.3.0.


transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

VarianceThresholdSelector

class pyspark.ml.feature.VarianceThresholdSelector(*args, **kwargs)
Feature selector that removes all low-variance features. Features with a variance not greater than the threshold will be removed. The default is to keep all features with non-zero variance, i.e. remove the features that have the same value in all samples.

New in version 3.1.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame(
...    [(Vectors.dense([6.0, 7.0, 0.0, 7.0, 6.0, 0.0]),),
...     (Vectors.dense([0.0, 9.0, 6.0, 0.0, 5.0, 9.0]),),
...     (Vectors.dense([0.0, 9.0, 3.0, 0.0, 5.0, 5.0]),),
...     (Vectors.dense([0.0, 9.0, 8.0, 5.0, 6.0, 4.0]),),
...     (Vectors.dense([8.0, 9.0, 6.0, 5.0, 4.0, 4.0]),),
...     (Vectors.dense([8.0, 9.0, 6.0, 0.0, 0.0, 0.0]),)],
...    ["features"])
>>> selector = VarianceThresholdSelector(varianceThreshold=8.2, outputCol="selectedFeatures")
>>> model = selector.fit(df)
>>> model.getFeaturesCol()
'features'
>>> model.setFeaturesCol("features")
VarianceThresholdSelectorModel...
>>> model.transform(df).head().selectedFeatures
DenseVector([6.0, 7.0, 0.0])
>>> model.selectedFeatures
[0, 3, 5]
>>> varianceThresholdSelectorPath = temp_path + "/variance-threshold-selector"
>>> selector.save(varianceThresholdSelectorPath)
>>> loadedSelector = VarianceThresholdSelector.load(varianceThresholdSelectorPath)
>>> loadedSelector.getVarianceThreshold() == selector.getVarianceThreshold()
True
>>> modelPath = temp_path + "/variance-threshold-selector-model"
>>> model.save(modelPath)
>>> loadedModel = VarianceThresholdSelectorModel.load(modelPath)
>>> loadedModel.selectedFeatures == model.selectedFeatures
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getFeaturesCol() Gets the value of featuresCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
getVarianceThreshold() Gets the value of varianceThreshold or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setOutputCol(value) Sets the value of outputCol.
setParams(self, \*[, featuresCol, ...]) Sets params for this VarianceThresholdSelector.
setVarianceThreshold(value) Sets the value of varianceThreshold.
write() Returns an MLWriter instance for this ML instance.

Attributes

featuresCol
outputCol
params Returns all params ordered by name.
varianceThreshold

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map


fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)
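When params is a list of param maps, one model is returned per map; a minimal sketch (reusing the selector and df from the example above, with arbitrary threshold values):

paramMaps = [
    {selector.varianceThreshold: 2.0},
    {selector.varianceThreshold: 8.2},
]
models = selector.fit(df, paramMaps)          # a list of fitted models, in the same order as paramMaps
[len(m.selectedFeatures) for m in models]     # how many features each threshold kept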

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

getVarianceThreshold()
Gets the value of varianceThreshold or its default value.

New in version 3.1.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).


classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.1.0.

setOutputCol(value)
Sets the value of outputCol.

New in version 3.1.0.

setParams(self, \*, featuresCol="features", outputCol=None, varianceThreshold=0.0)
Sets params for this VarianceThresholdSelector.

New in version 3.1.0.

setVarianceThreshold(value)
Sets the value of varianceThreshold.

New in version 3.1.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

varianceThreshold = Param(parent='undefined', name='varianceThreshold', doc='Param for variance threshold. Features with a variance not greater than this threshold will be removed. The default value is 0.0.')

VarianceThresholdSelectorModel

class pyspark.ml.feature.VarianceThresholdSelectorModel(java_model=None)
Model fitted by VarianceThresholdSelector.

New in version 3.1.0.
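A minimal sketch of inspecting a fitted model (reusing the selector and df from the VarianceThresholdSelector example above):

model = selector.fit(df)
model.selectedFeatures                        # indices of the features that were kept, e.g. [0, 3, 5]
model.transform(df).select("selectedFeatures").show(truncate=False)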


Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFeaturesCol() Gets the value of featuresCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
getVarianceThreshold() Gets the value of varianceThreshold or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setOutputCol(value) Sets the value of outputCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.


Attributes

featuresCol
outputCol
params Returns all params ordered by name.
selectedFeatures List of indices to select (filter).
varianceThreshold

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getFeaturesCol()
Gets the value of featuresCol or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.


getVarianceThreshold()
Gets the value of varianceThreshold or its default value.

New in version 3.1.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.1.0.

setOutputCol(value)
Sets the value of outputCol.

New in version 3.1.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.


Attributes Documentation

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

selectedFeatures
List of indices to select (filter).

New in version 3.1.0.

varianceThreshold = Param(parent='undefined', name='varianceThreshold', doc='Param for variance threshold. Features with a variance not greater than this threshold will be removed. The default value is 0.0.')

VectorAssembler

class pyspark.ml.feature.VectorAssembler(*args, **kwargs)
A feature transformer that merges multiple columns into a vector column.

New in version 1.4.0.

Examples

>>> df = spark.createDataFrame([(1, 0, 3)], ["a", "b", "c"])
>>> vecAssembler = VectorAssembler(outputCol="features")
>>> vecAssembler.setInputCols(["a", "b", "c"])
VectorAssembler...
>>> vecAssembler.transform(df).head().features
DenseVector([1.0, 0.0, 3.0])
>>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
DenseVector([1.0, 0.0, 3.0])
>>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
>>> vecAssembler.transform(df, params).head().vector
DenseVector([0.0, 1.0])
>>> vectorAssemblerPath = temp_path + "/vector-assembler"
>>> vecAssembler.save(vectorAssemblerPath)
>>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath)
>>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs
True
>>> dfWithNullsAndNaNs = spark.createDataFrame(
...    [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"])
>>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features",
...    handleInvalid="keep")
>>> vecAssembler2.transform(dfWithNullsAndNaNs).show()
+---+---+----+-------------+
|  a|  b|   c|     features|
+---+---+----+-------------+
|1.0|2.0|null|[1.0,2.0,NaN]|
|3.0|NaN| 4.0|[3.0,NaN,4.0]|
|5.0|6.0| 7.0|[5.0,6.0,7.0]|
+---+---+----+-------------+
...
>>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show()
+---+---+---+-------------+
|  a|  b|  c|     features|
+---+---+---+-------------+
|5.0|6.0|7.0|[5.0,6.0,7.0]|
+---+---+---+-------------+
...

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getHandleInvalid() Gets the value of handleInvalid or its default value.
getInputCols() Gets the value of inputCols or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setHandleInvalid(value) Sets the value of handleInvalid.
setInputCols(value) Sets the value of inputCols.
setOutputCol(value) Sets the value of outputCol.
setParams(self, \*[, inputCols, outputCol, ...]) Sets params for this VectorAssembler.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

handleInvalid
inputCols
outputCol
params Returns all params ordered by name.

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getHandleInvalid()
Gets the value of handleInvalid or its default value.

getInputCols()
Gets the value of inputCols or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.


getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setHandleInvalid(value)
Sets the value of handleInvalid.

setInputCols(value)
Sets the value of inputCols.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, \*, inputCols=None, outputCol=None, handleInvalid="error")
Sets params for this VectorAssembler.

New in version 1.4.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.


Attributes Documentation

handleInvalid = Param(parent='undefined', name='handleInvalid', doc="How to handle invalid data (NULL and NaN values). Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the output). Column lengths are taken from the size of ML Attribute Group, which can be set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred from first rows of the data since it is safe to do so but only in case of 'error' or 'skip').")

inputCols = Param(parent='undefined', name='inputCols', doc='input column names.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

VectorIndexer

class pyspark.ml.feature.VectorIndexer(*args, **kwargs)
Class for indexing categorical feature columns in a dataset of Vector.

This has 2 usage modes:

• Automatically identify categorical features (default behavior)

– This helps process a dataset of unknown vectors into a dataset with some continuous features and some categorical features. The choice between continuous and categorical is based upon a maxCategories parameter.

– Set maxCategories to the maximum number of categories any categorical feature should have.

– E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. If maxCategories = 2, then feature 0 will be declared categorical and use indices {0, 1}, and feature 1 will be declared continuous.

• Index all features, if all features are categorical

– If maxCategories is set to be very large, then this will build an index of unique values for all features.

– Warning: This can cause problems if features are continuous since this will collect ALL unique values to the driver.

– E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. If maxCategories >= 3, then both features will be declared categorical.

This returns a model which can transform categorical features to use 0-based indices.

Index stability:

• This is not guaranteed to choose the same category index across multiple runs.

• If a categorical feature includes value 0, then this is guaranteed to map value 0 to index 0. This maintains vector sparsity.

• More stability may be added in the future.

TODO: Future extensions: The following functionality is planned for the future:

• Preserve metadata in transform; if a feature’s metadata is already present, do not recompute.

• Specify certain features to not index, either via a parameter or via existing metadata.

• Add warning if a categorical feature has only 1 category.

New in version 1.4.0.


Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
...     (Vectors.dense([0.0, 1.0]),), (Vectors.dense([0.0, 2.0]),)], ["a"])
>>> indexer = VectorIndexer(maxCategories=2, inputCol="a")
>>> indexer.setOutputCol("indexed")
VectorIndexer...
>>> model = indexer.fit(df)
>>> indexer.getHandleInvalid()
'error'
>>> model.setOutputCol("output")
VectorIndexerModel...
>>> model.transform(df).head().output
DenseVector([1.0, 0.0])
>>> model.numFeatures
2
>>> model.categoryMaps
{0: {0.0: 0, -1.0: 1}}
>>> indexer.setParams(outputCol="test").fit(df).transform(df).collect()[1].test
DenseVector([0.0, 1.0])
>>> params = {indexer.maxCategories: 3, indexer.outputCol: "vector"}
>>> model2 = indexer.fit(df, params)
>>> model2.transform(df).head().vector
DenseVector([1.0, 0.0])
>>> vectorIndexerPath = temp_path + "/vector-indexer"
>>> indexer.save(vectorIndexerPath)
>>> loadedIndexer = VectorIndexer.load(vectorIndexerPath)
>>> loadedIndexer.getMaxCategories() == indexer.getMaxCategories()
True
>>> modelPath = temp_path + "/vector-indexer-model"
>>> model.save(modelPath)
>>> loadedModel = VectorIndexerModel.load(modelPath)
>>> loadedModel.numFeatures == model.numFeatures
True
>>> loadedModel.categoryMaps == model.categoryMaps
True
>>> loadedModel.transform(df).take(1) == model.transform(df).take(1)
True
>>> dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"])
>>> indexer.getHandleInvalid()
'error'
>>> model3 = indexer.setHandleInvalid("skip").fit(df)
>>> model3.transform(dfWithInvalid).count()
0
>>> model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df)
>>> model4.transform(dfWithInvalid).head().indexed
DenseVector([2.0, 1.0])


Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getHandleInvalid() Gets the value of handleInvalid or its default value.
getInputCol() Gets the value of inputCol or its default value.
getMaxCategories() Gets the value of maxCategories or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setHandleInvalid(value) Sets the value of handleInvalid.
setInputCol(value) Sets the value of inputCol.
setMaxCategories(value) Sets the value of maxCategories.
setOutputCol(value) Sets the value of outputCol.
setParams(self, \*[, maxCategories, ...]) Sets params for this VectorIndexer.
write() Returns an MLWriter instance for this ML instance.


Attributes

handleInvalid
inputCol
maxCategories
outputCol
params Returns all params ordered by name.

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns


Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getHandleInvalid()
Gets the value of handleInvalid or its default value.

getInputCol()
Gets the value of inputCol or its default value.

getMaxCategories()
Gets the value of maxCategories or its default value.

New in version 1.4.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.


setHandleInvalid(value)
Sets the value of handleInvalid.

setInputCol(value)
Sets the value of inputCol.

setMaxCategories(value)
Sets the value of maxCategories.

New in version 1.4.0.

setOutputCol(value)
Sets the value of outputCol.

setParams(self, \*, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
Sets params for this VectorIndexer.

New in version 1.4.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

handleInvalid = Param(parent='undefined', name='handleInvalid', doc="How to handle invalid data (unseen labels or NULL values). Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), or 'keep' (put invalid data in a special additional bucket, at index of the number of categories of the feature).")

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

maxCategories = Param(parent='undefined', name='maxCategories', doc='Threshold for the number of values a categorical feature can take (>= 2). If a feature is found to have > maxCategories values, then it is declared continuous.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

VectorIndexerModel

class pyspark.ml.feature.VectorIndexerModel(java_model=None)
Model fitted by VectorIndexer.

Transform categorical features to use 0-based indices instead of their original values.

• Categorical features are mapped to indices.

• Continuous features (columns) are left unchanged.

This also appends metadata to the output column, marking features as Numeric (continuous), Nominal (categorical), or Binary (either continuous or categorical). Non-ML metadata is not carried over from the input to the output column.

This maintains vector sparsity.

New in version 1.4.0.
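A minimal sketch of inspecting a fitted model (reusing the indexer, model and df from the VectorIndexer example above):

model.numFeatures        # length of the input vectors this model expects
model.categoryMaps       # {feature index: {original value: 0-based category index}}
model.setOutputCol("indexed")
model.transform(df).show()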


Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getHandleInvalid() Gets the value of handleInvalid or its default value.
getInputCol() Gets the value of inputCol or its default value.
getMaxCategories() Gets the value of maxCategories or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getOutputCol() Gets the value of outputCol or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setInputCol(value) Sets the value of inputCol.
setOutputCol(value) Sets the value of outputCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.


Attributes

categoryMaps Feature value index.
handleInvalid
inputCol
maxCategories
numFeatures Number of features, i.e., length of Vectors which this transforms.
outputCol
params Returns all params ordered by name.

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getHandleInvalid()
Gets the value of handleInvalid or its default value.

getInputCol()
Gets the value of inputCol or its default value.

getMaxCategories()
Gets the value of maxCategories or its default value.

New in version 1.4.0.


getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
Gets the value of outputCol or its default value.

getParam(paramName)
Gets a param by its name.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setInputCol(value)
Sets the value of inputCol.

New in version 3.0.0.

setOutputCol(value)
Sets the value of outputCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.


Attributes Documentation

categoryMaps
Feature value index. Keys are categorical feature indices (column indices). Values are maps from original feature values to 0-based category indices. If a feature is not in this map, it is treated as continuous.

New in version 1.4.0.

handleInvalid = Param(parent='undefined', name='handleInvalid', doc="How to handle invalid data (unseen labels or NULL values). Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), or 'keep' (put invalid data in a special additional bucket, at index of the number of categories of the feature).")

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

maxCategories = Param(parent='undefined', name='maxCategories', doc='Threshold for the number of values a categorical feature can take (>= 2). If a feature is found to have > maxCategories values, then it is declared continuous.')

numFeatures
Number of features, i.e., length of Vectors which this transforms.

New in version 1.4.0.

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

VectorSizeHint

class pyspark.ml.feature.VectorSizeHint(*args, **kwargs)
A feature transformer that adds size information to the metadata of a vector column. VectorAssembler needs size information for its input columns and cannot be used on streaming dataframes without this metadata.

New in version 2.3.0.

Notes

VectorSizeHint modifies inputCol to include size metadata and does not have an outputCol.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml import Pipeline, PipelineModel
>>> data = [(Vectors.dense([1., 2., 3.]), 4.)]
>>> df = spark.createDataFrame(data, ["vector", "float"])
>>>
>>> sizeHint = VectorSizeHint(inputCol="vector", size=3, handleInvalid="skip")
>>> vecAssembler = VectorAssembler(inputCols=["vector", "float"], outputCol="assembled")
>>> pipeline = Pipeline(stages=[sizeHint, vecAssembler])
>>>
>>> pipelineModel = pipeline.fit(df)
>>> pipelineModel.transform(df).head().assembled
DenseVector([1.0, 2.0, 3.0, 4.0])
>>> vectorSizeHintPath = temp_path + "/vector-size-hint-pipeline"
>>> pipelineModel.save(vectorSizeHintPath)
>>> loadedPipeline = PipelineModel.load(vectorSizeHintPath)
>>> loaded = loadedPipeline.transform(df).head().assembled
>>> expected = pipelineModel.transform(df).head().assembled
>>> loaded == expected
True

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getHandleInvalid() Gets the value of handleInvalid or its default value.
getInputCol() Gets the value of inputCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getSize() Gets size param, the size of vectors in inputCol.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setHandleInvalid(value) Sets the value of handleInvalid.
setInputCol(value) Sets the value of inputCol.
setParams(self, \*[, inputCol, size, ...]) Sets params for this VectorSizeHint.
setSize(value) Sets size param, the size of vectors in inputCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.


Attributes

handleInvalid
inputCol
params: Returns all params ordered by name.
size

Methods Documentation

clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map
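As a concrete illustration of that merge ordering, a sketch (not from the original docs) reusing the sizeHint instance built in the example above, whose handleInvalid was explicitly set to 'skip' while its default is 'error':

    user_map = sizeHint.extractParamMap()
    user_map[sizeHint.handleInvalid]                               # 'skip': user-supplied beats the default
    merged = sizeHint.extractParamMap({sizeHint.handleInvalid: "error"})
    merged[sizeHint.handleInvalid]                                 # 'error': the extra map beats the user-supplied value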

getHandleInvalid()
    Gets the value of handleInvalid or its default value.

getInputCol()
    Gets the value of inputCol or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getSize()
    Gets size param, the size of vectors in inputCol.

    New in version 2.3.0.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setHandleInvalid(value)
    Sets the value of handleInvalid.

setInputCol(value)
    Sets the value of inputCol.

setParams(self, *, inputCol=None, size=None, handleInvalid="error")
    Sets params for this VectorSizeHint.

    New in version 2.3.0.

setSize(value)
    Sets size param, the size of vectors in inputCol.

    New in version 2.3.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset
        params [dict, optional] an optional param map that overrides embedded params.

    Returns
        pyspark.sql.DataFrame transformed dataset

write()
    Returns an MLWriter instance for this ML instance.


Attributes Documentation

handleInvalid = Param(parent='undefined', name='handleInvalid', doc='How to handle invalid vectors in inputCol. Invalid vectors include nulls and vectors with the wrong size. The options are `skip` (filter out rows with invalid vectors), `error` (throw an error) and `optimistic` (do not check the vector size, and keep all rows). `error` by default.')

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

size = Param(parent='undefined', name='size', doc='Size of vectors in column.')

VectorSlicer

class pyspark.ml.feature.VectorSlicer(*args, **kwargs)
    This class takes a feature vector and outputs a new feature vector with a subarray of the original features.

    The subset of features can be specified with either indices (setIndices()) or names (setNames()). At least one feature must be selected. Duplicate features are not allowed, so there can be no overlap between selected indices and names.

    The output vector will order features with the selected indices first (in the order given), followed by the selected names (in the order given).

New in version 1.6.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (Vectors.dense([-2.0, 2.3, 0.0, 0.0, 1.0]),),
...     (Vectors.dense([0.0, 0.0, 0.0, 0.0, 0.0]),),
...     (Vectors.dense([0.6, -1.1, -3.0, 4.5, 3.3]),)], ["features"])
>>> vs = VectorSlicer(outputCol="sliced", indices=[1, 4])
>>> vs.setInputCol("features")
VectorSlicer...
>>> vs.transform(df).head().sliced
DenseVector([2.3, 1.0])
>>> vectorSlicerPath = temp_path + "/vector-slicer"
>>> vs.save(vectorSlicerPath)
>>> loadedVs = VectorSlicer.load(vectorSlicerPath)
>>> loadedVs.getIndices() == vs.getIndices()
True
>>> loadedVs.getNames() == vs.getNames()
True
>>> loadedVs.transform(df).take(1) == vs.transform(df).take(1)
True
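The doctest above selects by indices. Selecting by names needs ML attribute metadata on the vector column; a hedged sketch (the column names x, y, z are made up for illustration) that relies on VectorAssembler recording its input column names as vector attributes:

    from pyspark.ml.feature import VectorAssembler, VectorSlicer

    raw = spark.createDataFrame([(1.0, 10.0, 100.0)], ["x", "y", "z"])
    assembled = VectorAssembler(inputCols=["x", "y", "z"], outputCol="features").transform(raw)

    slicer = VectorSlicer(inputCol="features", outputCol="xz", names=["x", "z"])
    slicer.transform(assembled).head().xz     # DenseVector([1.0, 100.0])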


Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getIndices(): Gets the value of indices or its default value.
getInputCol(): Gets the value of inputCol or its default value.
getNames(): Gets the value of names or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setIndices(value): Sets the value of indices.
setInputCol(value): Sets the value of inputCol.
setNames(value): Sets the value of names.
setOutputCol(value): Sets the value of outputCol.
setParams(*args, **kwargs): Sets params for this VectorSlicer (signature: setParams(self, *, inputCol=None, outputCol=None, indices=None, names=None)).
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.


Attributes

indices
inputCol
names
outputCol
params: Returns all params ordered by name.

Methods Documentation

clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map

getIndices()
    Gets the value of indices or its default value.

    New in version 1.6.0.

getInputCol()
    Gets the value of inputCol or its default value.

getNames()
    Gets the value of names or its default value.

    New in version 1.6.0.


getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setIndices(value)
    Sets the value of indices.

    New in version 1.6.0.

setInputCol(value)
    Sets the value of inputCol.

setNames(value)
    Sets the value of names.

    New in version 1.6.0.

setOutputCol(value)
    Sets the value of outputCol.

setParams(*args, **kwargs)
    setParams(self, *, inputCol=None, outputCol=None, indices=None, names=None): Sets params for this VectorSlicer.

    New in version 1.6.0.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset
        params [dict, optional] an optional param map that overrides embedded params.

    Returns
        pyspark.sql.DataFrame transformed dataset

write()
    Returns an MLWriter instance for this ML instance.

Attributes Documentation

indices = Param(parent='undefined', name='indices', doc='An array of indices to select features from a vector column. There can be no overlap with names.')

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

names = Param(parent='undefined', name='names', doc='An array of feature names to select features from a vector column. These names must be specified by ML org.apache.spark.ml.attribute.Attribute. There can be no overlap with indices.')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

Word2Vec

class pyspark.ml.feature.Word2Vec(*args, **kwargs)
    Word2Vec trains a model of Map(String, Vector), i.e. transforms a word into a code for further natural language processing or machine learning process.

New in version 1.4.0.

Examples

>>> sent = ("a b " * 100 + "a c " * 10).split(" ")
>>> doc = spark.createDataFrame([(sent,), (sent,)], ["sentence"])
>>> word2Vec = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model")
>>> word2Vec.setMaxIter(10)
Word2Vec...
>>> word2Vec.getMaxIter()
10
>>> word2Vec.clear(word2Vec.maxIter)
>>> model = word2Vec.fit(doc)
>>> model.getMinCount()
5
>>> model.setInputCol("sentence")
Word2VecModel...
>>> model.getVectors().show()
+----+--------------------+
|word|              vector|
+----+--------------------+
|   a|[0.09511678665876...|
|   b|[-1.2028766870498...|
|   c|[0.30153277516365...|
+----+--------------------+
...
>>> model.findSynonymsArray("a", 2)
[('b', 0.015859870240092278), ('c', -0.5680795907974243)]
>>> from pyspark.sql.functions import format_number as fmt
>>> model.findSynonyms("a", 2).select("word", fmt("similarity", 5).alias("similarity")).show()
+----+----------+
|word|similarity|
+----+----------+
|   b|   0.01586|
|   c|  -0.56808|
+----+----------+
...
>>> model.transform(doc).head().model
DenseVector([-0.4833, 0.1855, -0.273, -0.0509, -0.4769])
>>> word2vecPath = temp_path + "/word2vec"
>>> word2Vec.save(word2vecPath)
>>> loadedWord2Vec = Word2Vec.load(word2vecPath)
>>> loadedWord2Vec.getVectorSize() == word2Vec.getVectorSize()
True
>>> loadedWord2Vec.getNumPartitions() == word2Vec.getNumPartitions()
True
>>> loadedWord2Vec.getMinCount() == word2Vec.getMinCount()
True
>>> modelPath = temp_path + "/word2vec-model"
>>> model.save(modelPath)
>>> loadedModel = Word2VecModel.load(modelPath)
>>> loadedModel.getVectors().first().word == model.getVectors().first().word
True
>>> loadedModel.getVectors().first().vector == model.getVectors().first().vector
True
>>> loadedModel.transform(doc).take(1) == model.transform(doc).take(1)
True

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getInputCol(): Gets the value of inputCol or its default value.
getMaxIter(): Gets the value of maxIter or its default value.
getMaxSentenceLength(): Gets the value of maxSentenceLength or its default value.
getMinCount(): Gets the value of minCount or its default value.
getNumPartitions(): Gets the value of numPartitions or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
getSeed(): Gets the value of seed or its default value.
getStepSize(): Gets the value of stepSize or its default value.
getVectorSize(): Gets the value of vectorSize or its default value.
getWindowSize(): Gets the value of windowSize or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setMaxIter(value): Sets the value of maxIter.
setMaxSentenceLength(value): Sets the value of maxSentenceLength.
setMinCount(value): Sets the value of minCount.
setNumPartitions(value): Sets the value of numPartitions.
setOutputCol(value): Sets the value of outputCol.
setParams(self, *[, minCount, ...]): Sets params for this Word2Vec.
setSeed(value): Sets the value of seed.
setStepSize(value): Sets the value of stepSize.
setVectorSize(value): Sets the value of vectorSize.
setWindowSize(value): Sets the value of windowSize.
write(): Returns an MLWriter instance for this ML instance.

Attributes

inputCol
maxIter
maxSentenceLength
minCount
numPartitions
outputCol
params: Returns all params ordered by name.
seed
stepSize
vectorSize
windowSize

Methods Documentation

clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map

fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

    Returns
        Transformer or a list of Transformer: fitted model(s)


fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.

    New in version 2.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        paramMaps [collections.abc.Sequence] A Sequence of param maps.

    Returns
        _FitMultipleIterator: A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
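A short sketch (not from the original docs) of consuming that iterator, reusing the word2Vec estimator and doc DataFrame from the example above; param maps are plain dicts keyed by Param objects:

    param_maps = [{word2Vec.vectorSize: 3}, {word2Vec.vectorSize: 10}]
    models = [None] * len(param_maps)
    for index, fitted in word2Vec.fitMultiple(doc, param_maps):
        models[index] = fitted        # index points back into param_maps
    # models[0] was fit with vectorSize=3, models[1] with vectorSize=10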

getInputCol()
    Gets the value of inputCol or its default value.

getMaxIter()
    Gets the value of maxIter or its default value.

getMaxSentenceLength()
    Gets the value of maxSentenceLength or its default value.

    New in version 2.0.0.

getMinCount()
    Gets the value of minCount or its default value.

    New in version 1.4.0.

getNumPartitions()
    Gets the value of numPartitions or its default value.

    New in version 1.4.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

getSeed()
    Gets the value of seed or its default value.

getStepSize()
    Gets the value of stepSize or its default value.

getVectorSize()
    Gets the value of vectorSize or its default value.

    New in version 1.4.0.

getWindowSize()
    Gets the value of windowSize or its default value.

    New in version 2.0.0.

hasDefault(param)
    Checks whether a param has a default value.


hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

setMaxIter(value)
    Sets the value of maxIter.

setMaxSentenceLength(value)
    Sets the value of maxSentenceLength.

    New in version 2.0.0.

setMinCount(value)
    Sets the value of minCount.

    New in version 1.4.0.

setNumPartitions(value)
    Sets the value of numPartitions.

    New in version 1.4.0.

setOutputCol(value)
    Sets the value of outputCol.

setParams(self, *, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None, inputCol=None, outputCol=None, windowSize=5, maxSentenceLength=1000)
    Sets params for this Word2Vec.

    New in version 1.4.0.

setSeed(value)
    Sets the value of seed.

setStepSize(value)
    Sets the value of stepSize.

    New in version 1.4.0.

setVectorSize(value)
    Sets the value of vectorSize.

    New in version 1.4.0.


setWindowSize(value)
    Sets the value of windowSize.

    New in version 2.0.0.

write()
    Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

maxSentenceLength = Param(parent='undefined', name='maxSentenceLength', doc='Maximum length (in words) of each sentence in the input data. Any sentence longer than this threshold will be divided into chunks up to the size.')

minCount = Param(parent='undefined', name='minCount', doc="the minimum number of times a token must appear to be included in the word2vec model's vocabulary")

numPartitions = Param(parent='undefined', name='numPartitions', doc='number of partitions for sentences of words')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

seed = Param(parent='undefined', name='seed', doc='random seed.')

stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')

vectorSize = Param(parent='undefined', name='vectorSize', doc='the dimension of codes after transforming from words')

windowSize = Param(parent='undefined', name='windowSize', doc='the window size (context words from [-window, window]). Default value is 5')

Word2VecModel

class pyspark.ml.feature.Word2VecModel(java_model=None)
    Model fitted by Word2Vec.

New in version 1.4.0.

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
findSynonyms(word, num): Find "num" number of words closest in similarity to "word".
findSynonymsArray(word, num): Find "num" number of words closest in similarity to "word".
getInputCol(): Gets the value of inputCol or its default value.
getMaxIter(): Gets the value of maxIter or its default value.
getMaxSentenceLength(): Gets the value of maxSentenceLength or its default value.
getMinCount(): Gets the value of minCount or its default value.
getNumPartitions(): Gets the value of numPartitions or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getOutputCol(): Gets the value of outputCol or its default value.
getParam(paramName): Gets a param by its name.
getSeed(): Gets the value of seed or its default value.
getStepSize(): Gets the value of stepSize or its default value.
getVectorSize(): Gets the value of vectorSize or its default value.
getVectors(): Returns the vector representation of the words as a dataframe with two fields, word and vector.
getWindowSize(): Gets the value of windowSize or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setInputCol(value): Sets the value of inputCol.
setOutputCol(value): Sets the value of outputCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.


Attributes

inputCol
maxIter
maxSentenceLength
minCount
numPartitions
outputCol
params: Returns all params ordered by name.
seed
stepSize
vectorSize
windowSize

Methods Documentation

clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map

findSynonyms(word, num)
    Find "num" number of words closest in similarity to "word". word can be a string or vector representation. Returns a dataframe with two fields word and similarity (which gives the cosine similarity).

    New in version 1.5.0.


findSynonymsArray(word, num)
    Find "num" number of words closest in similarity to "word". word can be a string or vector representation. Returns an array with two fields word and similarity (which gives the cosine similarity).

    New in version 2.3.0.
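Both findSynonyms and findSynonymsArray also accept a vector as the query, which is useful for averaged or arithmetically combined word vectors. A small sketch, not part of the original doctest, reusing the fitted model from the Word2Vec example above:

    vec_a = model.getVectors().where("word = 'a'").head().vector
    model.findSynonyms(vec_a, 2).show()     # DataFrame with columns word and similarity
    model.findSynonymsArray(vec_a, 2)       # list of (word, similarity) tuples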

getInputCol()
    Gets the value of inputCol or its default value.

getMaxIter()
    Gets the value of maxIter or its default value.

getMaxSentenceLength()
    Gets the value of maxSentenceLength or its default value.

    New in version 2.0.0.

getMinCount()
    Gets the value of minCount or its default value.

    New in version 1.4.0.

getNumPartitions()
    Gets the value of numPartitions or its default value.

    New in version 1.4.0.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()
    Gets the value of outputCol or its default value.

getParam(paramName)
    Gets a param by its name.

getSeed()
    Gets the value of seed or its default value.

getStepSize()
    Gets the value of stepSize or its default value.

getVectorSize()
    Gets the value of vectorSize or its default value.

    New in version 1.4.0.

getVectors()
    Returns the vector representation of the words as a dataframe with two fields, word and vector.

    New in version 1.5.0.

getWindowSize()
    Gets the value of windowSize or its default value.

    New in version 2.0.0.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.


isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setInputCol(value)
    Sets the value of inputCol.

setOutputCol(value)
    Sets the value of outputCol.

transform(dataset, params=None)
    Transforms the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset
        params [dict, optional] an optional param map that overrides embedded params.

    Returns
        pyspark.sql.DataFrame transformed dataset

write()
    Returns an MLWriter instance for this ML instance.

Attributes Documentation

inputCol = Param(parent='undefined', name='inputCol', doc='input column name.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

maxSentenceLength = Param(parent='undefined', name='maxSentenceLength', doc='Maximum length (in words) of each sentence in the input data. Any sentence longer than this threshold will be divided into chunks up to the size.')

minCount = Param(parent='undefined', name='minCount', doc="the minimum number of times a token must appear to be included in the word2vec model's vocabulary")

numPartitions = Param(parent='undefined', name='numPartitions', doc='number of partitions for sentences of words')

outputCol = Param(parent='undefined', name='outputCol', doc='output column name.')

params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

seed = Param(parent='undefined', name='seed', doc='random seed.')

stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')

vectorSize = Param(parent='undefined', name='vectorSize', doc='the dimension of codes after transforming from words')


windowSize = Param(parent='undefined', name='windowSize', doc='the window size (context words from [-window, window]). Default value is 5')

3.3.4 Classification

LinearSVC(*args, **kwargs): This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.
LinearSVCModel([java_model]): Model fitted by LinearSVC.
LinearSVCSummary([java_obj]): Abstraction for LinearSVC Results for a given model.
LinearSVCTrainingSummary([java_obj]): Abstraction for LinearSVC Training results.
LogisticRegression(*args, **kwargs): Logistic regression.
LogisticRegressionModel([java_model]): Model fitted by LogisticRegression.
LogisticRegressionSummary([java_obj]): Abstraction for Logistic Regression Results for a given model.
LogisticRegressionTrainingSummary([java_obj]): Abstraction for multinomial Logistic Regression Training results.
BinaryLogisticRegressionSummary([java_obj]): Binary Logistic regression results for a given model.
BinaryLogisticRegressionTrainingSummary([...]): Binary Logistic regression training results for a given model.
DecisionTreeClassifier(*args, **kwargs): Decision tree learning algorithm for classification. It supports both binary and multiclass labels, as well as both continuous and categorical features.
DecisionTreeClassificationModel([java_model]): Model fitted by DecisionTreeClassifier.
GBTClassifier(*args, **kwargs): Gradient-Boosted Trees (GBTs) learning algorithm for classification. It supports binary labels, as well as both continuous and categorical features.
GBTClassificationModel([java_model]): Model fitted by GBTClassifier.
RandomForestClassifier(*args, **kwargs): Random Forest learning algorithm for classification. It supports both binary and multiclass labels, as well as both continuous and categorical features.
RandomForestClassificationModel([java_model]): Model fitted by RandomForestClassifier.
RandomForestClassificationSummary([java_obj]): Abstraction for RandomForestClassification Results for a given model.
RandomForestClassificationTrainingSummary([...]): Abstraction for RandomForestClassification Training results.
BinaryRandomForestClassificationSummary([...]): BinaryRandomForestClassification results for a given model.
BinaryRandomForestClassificationTrainingSummary([...]): BinaryRandomForestClassification training results for a given model.
NaiveBayes(*args, **kwargs): Naive Bayes Classifiers.
NaiveBayesModel([java_model]): Model fitted by NaiveBayes.
MultilayerPerceptronClassifier(*args, **kwargs): Classifier trainer based on the Multilayer Perceptron.
MultilayerPerceptronClassificationModel([...]): Model fitted by MultilayerPerceptronClassifier.
MultilayerPerceptronClassificationSummary([...]): Abstraction for MultilayerPerceptronClassifier Results for a given model.
MultilayerPerceptronClassificationTrainingSummary([...]): Abstraction for MultilayerPerceptronClassifier Training results.
OneVsRest(*args, **kwargs): Reduction of Multiclass Classification to Binary Classification.
OneVsRestModel(models): Model fitted by OneVsRest.
FMClassifier(*args, **kwargs): Factorization Machines learning algorithm for classification.
FMClassificationModel([java_model]): Model fitted by FMClassifier.
FMClassificationSummary([java_obj]): Abstraction for FMClassifier Results for a given model.
FMClassificationTrainingSummary([java_obj]): Abstraction for FMClassifier Training results.
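All of the estimators listed above follow the same fit/transform pattern; a minimal, hedged sketch (the training DataFrame train with 'label' and 'features' columns is hypothetical):

    from pyspark.ml.classification import LogisticRegression

    lr = LogisticRegression(maxIter=10, regParam=0.01)   # any classifier above is configured the same way
    model = lr.fit(train)                                # Estimator.fit() returns the matching *Model
    predictions = model.transform(train)                 # the fitted model is a Transformer
    predictions.select("label", "prediction").show()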

LinearSVC

class pyspark.ml.classification.LinearSVC(*args, **kwargs)
    This binary classifier optimizes the Hinge Loss using the OWLQN optimizer. Only supports L2 regularization currently.

New in version 2.2.0.

Notes

Linear SVM Classifier

Examples

>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
>>> df = sc.parallelize([
...     Row(label=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
...     Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
>>> svm = LinearSVC()
>>> svm.getMaxIter()
100
>>> svm.setMaxIter(5)
LinearSVC...
>>> svm.getMaxIter()
5
>>> svm.getRegParam()
0.0
>>> svm.setRegParam(0.01)
LinearSVC...
>>> svm.getRegParam()
0.01
>>> model = svm.fit(df)
>>> model.setPredictionCol("newPrediction")
LinearSVCModel...
>>> model.getPredictionCol()
'newPrediction'
>>> model.setThreshold(0.5)
LinearSVCModel...
>>> model.getThreshold()
0.5
>>> model.getMaxBlockSizeInMB()
0.0
>>> model.coefficients
DenseVector([0.0, -0.2792, -0.1833])
>>> model.intercept
1.0206118982229047
>>> model.numClasses
2
>>> model.numFeatures
3
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, -1.0, -1.0))]).toDF()
>>> model.predict(test0.head().features)
1.0
>>> model.predictRaw(test0.head().features)
DenseVector([-1.4831, 1.4831])
>>> result = model.transform(test0).head()
>>> result.newPrediction
1.0
>>> result.rawPrediction
DenseVector([-1.4831, 1.4831])
>>> svm_path = temp_path + "/svm"
>>> svm.save(svm_path)
>>> svm2 = LinearSVC.load(svm_path)
>>> svm2.getMaxIter()
5
>>> model_path = temp_path + "/svm_model"
>>> model.save(model_path)
>>> model2 = LinearSVCModel.load(model_path)
>>> model.coefficients[0] == model2.coefficients[0]
True
>>> model.intercept == model2.intercept
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getAggregationDepth(): Gets the value of aggregationDepth or its default value.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getFitIntercept(): Gets the value of fitIntercept or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getMaxBlockSizeInMB(): Gets the value of maxBlockSizeInMB or its default value.
getMaxIter(): Gets the value of maxIter or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getRawPredictionCol(): Gets the value of rawPredictionCol or its default value.
getRegParam(): Gets the value of regParam or its default value.
getStandardization(): Gets the value of standardization or its default value.
getThreshold(): Gets the value of threshold or its default value.
getTol(): Gets the value of tol or its default value.
getWeightCol(): Gets the value of weightCol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setAggregationDepth(value): Sets the value of aggregationDepth.
setFeaturesCol(value): Sets the value of featuresCol.
setFitIntercept(value): Sets the value of fitIntercept.
setLabelCol(value): Sets the value of labelCol.
setMaxBlockSizeInMB(value): Sets the value of maxBlockSizeInMB.
setMaxIter(value): Sets the value of maxIter.
setParams(*args, **kwargs): Sets params for Linear SVM Classifier (signature: setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0)).
setPredictionCol(value): Sets the value of predictionCol.
setRawPredictionCol(value): Sets the value of rawPredictionCol.
setRegParam(value): Sets the value of regParam.
setStandardization(value): Sets the value of standardization.
setThreshold(value): Sets the value of threshold.
setTol(value): Sets the value of tol.
setWeightCol(value): Sets the value of weightCol.
write(): Returns an MLWriter instance for this ML instance.


Attributes

aggregationDepth
featuresCol
fitIntercept
labelCol
maxBlockSizeInMB
maxIter
params: Returns all params ordered by name.
predictionCol
rawPredictionCol
regParam
standardization
threshold
tol
weightCol

Methods Documentation

clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map

fit(dataset, params=None)
    Fits a model to the input dataset with optional parameters.

    New in version 1.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

    Returns
        Transformer or a list of Transformer: fitted model(s)

fitMultiple(dataset, paramMaps)
    Fits a model to the input dataset for each param map in paramMaps.

    New in version 2.3.0.

    Parameters
        dataset [pyspark.sql.DataFrame] input dataset.
        paramMaps [collections.abc.Sequence] A Sequence of param maps.

    Returns
        _FitMultipleIterator: A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getAggregationDepth()
    Gets the value of aggregationDepth or its default value.

getFeaturesCol()
    Gets the value of featuresCol or its default value.

getFitIntercept()
    Gets the value of fitIntercept or its default value.

getLabelCol()
    Gets the value of labelCol or its default value.

getMaxBlockSizeInMB()
    Gets the value of maxBlockSizeInMB or its default value.

getMaxIter()
    Gets the value of maxIter or its default value.

getOrDefault(param)
    Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
    Gets a param by its name.

getPredictionCol()
    Gets the value of predictionCol or its default value.

getRawPredictionCol()
    Gets the value of rawPredictionCol or its default value.

getRegParam()
    Gets the value of regParam or its default value.

getStandardization()
    Gets the value of standardization or its default value.


getThreshold()
    Gets the value of threshold or its default value.

getTol()
    Gets the value of tol or its default value.

getWeightCol()
    Gets the value of weightCol or its default value.

hasDefault(param)
    Checks whether a param has a default value.

hasParam(paramName)
    Tests whether this instance contains a param with a given (string) name.

isDefined(param)
    Checks whether a param is explicitly set by user or has a default value.

isSet(param)
    Checks whether a param is explicitly set by user.

classmethod load(path)
    Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
    Returns an MLReader instance for this class.

save(path)
    Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
    Sets a parameter in the embedded param map.

setAggregationDepth(value)
    Sets the value of aggregationDepth.

    New in version 2.2.0.

setFeaturesCol(value)
    Sets the value of featuresCol.

    New in version 3.0.0.

setFitIntercept(value)
    Sets the value of fitIntercept.

    New in version 2.2.0.

setLabelCol(value)
    Sets the value of labelCol.

    New in version 3.0.0.

setMaxBlockSizeInMB(value)
    Sets the value of maxBlockSizeInMB.

    New in version 3.1.0.

setMaxIter(value)
    Sets the value of maxIter.

    New in version 2.2.0.

setParams(*args, **kwargs)
    setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0): Sets params for Linear SVM Classifier.

    New in version 2.2.0.

setPredictionCol(value)
    Sets the value of predictionCol.

    New in version 3.0.0.

setRawPredictionCol(value)
    Sets the value of rawPredictionCol.

    New in version 3.0.0.

setRegParam(value)
    Sets the value of regParam.

    New in version 2.2.0.

setStandardization(value)
    Sets the value of standardization.

    New in version 2.2.0.

setThreshold(value)
    Sets the value of threshold.

    New in version 2.2.0.

setTol(value)
    Sets the value of tol.

    New in version 2.2.0.

setWeightCol(value)
    Sets the value of weightCol.

    New in version 2.2.0.

write()
    Returns an MLWriter instance for this ML instance.

Attributes Documentation

aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

params
    Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')


regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')

standardization = Param(parent='undefined', name='standardization', doc='whether to standardize the training features before fitting the model.')

threshold = Param(parent='undefined', name='threshold', doc='The threshold in binary classification applied to the linear model prediction. This threshold can be any real number, where Inf will make all predictions 0.0 and -Inf will make all predictions 1.0.')

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

LinearSVCModel

class pyspark.ml.classification.LinearSVCModel(java_model=None)
    Model fitted by LinearSVC.

New in version 2.2.0.

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset): Evaluates the model on a test dataset.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getAggregationDepth(): Gets the value of aggregationDepth or its default value.
getFeaturesCol(): Gets the value of featuresCol or its default value.
getFitIntercept(): Gets the value of fitIntercept or its default value.
getLabelCol(): Gets the value of labelCol or its default value.
getMaxBlockSizeInMB(): Gets the value of maxBlockSizeInMB or its default value.
getMaxIter(): Gets the value of maxIter or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getRawPredictionCol(): Gets the value of rawPredictionCol or its default value.
getRegParam(): Gets the value of regParam or its default value.
getStandardization(): Gets the value of standardization or its default value.
getThreshold(): Gets the value of threshold or its default value.
getTol(): Gets the value of tol or its default value.
getWeightCol(): Gets the value of weightCol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value): Predict label for the given features.
predictRaw(value): Raw prediction for each possible label.
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setFeaturesCol(value): Sets the value of featuresCol.
setPredictionCol(value): Sets the value of predictionCol.
setRawPredictionCol(value): Sets the value of rawPredictionCol.
setThreshold(value): Sets the value of threshold.
summary(): Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on the training set.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.

Attributes

aggregationDepth
coefficients: Model coefficients of Linear SVM Classifier.
featuresCol
fitIntercept
hasSummary: Indicates whether a training summary exists for this model instance.
intercept: Model intercept of Linear SVM Classifier.
labelCol
maxBlockSizeInMB
maxIter
numClasses: Number of classes (values which the label can take).
numFeatures: Returns the number of features the model was trained on.
params: Returns all params ordered by name.
predictionCol
rawPredictionCol
regParam
standardization
threshold
tol
weightCol


Methods Documentation

clear(param)
    Clears a param from the param map if it has been explicitly set.

copy(extra=None)
    Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

    Parameters
        extra [dict, optional] Extra parameters to copy to the new instance

    Returns
        JavaParams Copy of this instance

evaluate(dataset)
    Evaluates the model on a test dataset.

    New in version 3.1.0.

    Parameters
        dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on.
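A short sketch, not part of the original doctest, of evaluating on held-out data; it assumes the model and df from the LinearSVC example above and that the returned LinearSVCSummary exposes the usual binary classification metrics:

    test_summary = model.evaluate(df)        # LinearSVCSummary computed on the supplied DataFrame
    test_summary.accuracy
    test_summary.areaUnderROC
    test_summary.predictions.show()          # the per-row predictions behind the metrics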

explainParam(param)
    Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
    Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
    Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

    Parameters
        extra [dict, optional] extra param values

    Returns
        dict merged param map

getAggregationDepth()
    Gets the value of aggregationDepth or its default value.

getFeaturesCol()
    Gets the value of featuresCol or its default value.

getFitIntercept()
    Gets the value of fitIntercept or its default value.

getLabelCol()
    Gets the value of labelCol or its default value.

getMaxBlockSizeInMB()
    Gets the value of maxBlockSizeInMB or its default value.

getMaxIter()
    Gets the value of maxIter or its default value.


getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getRawPredictionCol()Gets the value of rawPredictionCol or its default value.

getRegParam()Gets the value of regParam or its default value.

getStandardization()Gets the value of standardization or its default value.

getThreshold()Gets the value of threshold or its default value.

getTol()Gets the value of tol or its default value.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)Predict label for the given features.

New in version 3.0.0.

predictRaw(value)Raw prediction for each possible label.

New in version 3.0.0.

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeaturesCol(value)Sets the value of featuresCol.


New in version 3.0.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setRawPredictionCol(value)Sets the value of rawPredictionCol.

New in version 3.0.0.

setThreshold(value)Sets the value of threshold.

New in version 3.0.0.

summary()Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on thetraining set. An exception is thrown if trainingSummary is None.

New in version 3.1.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset
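
A short sketch of a per-call param override during transform; `model` and `test_df` are assumed to exist, and the overridden column name is hypothetical. The params dict overrides the embedded params for this call only:

    # Assumes `model` is a fitted LinearSVCModel and `test_df` has a "features" column.
    out = model.transform(test_df, {model.predictionCol: "svc_prediction"})
    out.select("svc_prediction").show()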

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')

coefficientsModel coefficients of Linear SVM Classifier.

New in version 2.2.0.

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')

hasSummaryIndicates whether a training summary exists for this model instance.

New in version 2.1.0.

interceptModel intercept of Linear SVM Classifier.

New in version 2.2.0.

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')


maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

numClassesNumber of classes (values which the label can take).

New in version 2.1.0.

numFeaturesReturns the number of features the model was trained on. If unknown, returns -1

New in version 2.1.0.

paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')

standardization = Param(parent='undefined', name='standardization', doc='whether to standardize the training features before fitting the model.')

threshold = Param(parent='undefined', name='threshold', doc='The threshold in binary classification applied to the linear model prediction. This threshold can be any real number, where Inf will make all predictions 0.0 and -Inf will make all predictions 1.0.')

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

LinearSVCSummary

class pyspark.ml.classification.LinearSVCSummary(java_obj=None)
Abstraction for LinearSVC Results for a given model.

New in version 3.1.0.
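
A LinearSVCSummary is not constructed by hand; it is obtained from a fitted LinearSVCModel, either as the training summary or by evaluating new data. A brief sketch, assuming a fitted model `model` and a labeled DataFrame `test_df` (both hypothetical):

    train_summary = model.summary            # LinearSVCTrainingSummary from the fit
    test_summary = model.evaluate(test_df)   # LinearSVCSummary over the new data

    print(test_summary.accuracy)
    print(test_summary.areaUnderROC)
    test_summary.roc.show(5)                 # DataFrame with FPR and TPR fields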

Methods

fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.

Attributes

accuracy  Returns accuracy.
areaUnderROC  Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold  Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
labelCol  Field in “predictions” which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
pr  Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel  Returns precision for each label (category).
precisionByThreshold  Returns a dataframe with two fields (threshold, precision) curve.
predictionCol  Field in “predictions” which gives the prediction of each class.
predictions  Dataframe outputted by the model’s transform method.
recallByLabel  Returns recall for each label (category).
recallByThreshold  Returns a dataframe with two fields (threshold, recall) curve.
roc  Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol  Field in “predictions” which gives the probability or raw prediction of each class as a vector.
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in “predictions” which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.

Methods Documentation

fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.

New in version 3.1.0.

Attributes Documentation

accuracyReturns accuracy. (equals to the total number of correctly classified instances out of the total number ofinstances.)

New in version 3.1.0.

areaUnderROCComputes the area under the receiver operating characteristic (ROC) curve.

New in version 3.1.0.

fMeasureByThresholdReturns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.


New in version 3.1.0.

falsePositiveRateByLabelReturns false positive rate for each label (category).

New in version 3.1.0.

labelColField in “predictions” which gives the true label of each instance.

New in version 3.1.0.

labelsReturns the sequence of labels in ascending order. This order matches the order used in metrics which arespecified as arrays over labels, e.g., truePositiveRateByLabel.

New in version 3.1.0.

Notes

In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.

prReturns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0,1.0) prepended to it.

New in version 3.1.0.

precisionByLabelReturns precision for each label (category).

New in version 3.1.0.

precisionByThresholdReturns a dataframe with two fields (threshold, precision) curve. Every possible probability obtained intransforming the dataset are used as thresholds used in calculating the precision.

New in version 3.1.0.

predictionColField in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictionsDataframe outputted by the model’s transform method.

New in version 3.1.0.

recallByLabelReturns recall for each label (category).

New in version 3.1.0.

recallByThresholdReturns a dataframe with two fields (threshold, recall) curve. Every possible probability obtained in trans-forming the dataset are used as thresholds used in calculating the recall.

New in version 3.1.0.


rocReturns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR,TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.

New in version 3.1.0.

Notes

Wikipedia reference

scoreColField in “predictions” which gives the probability or raw prediction of each class as a vector.

New in version 3.1.0.

truePositiveRateByLabelReturns true positive rate for each label (category).

New in version 3.1.0.

weightColField in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRateReturns weighted false positive rate.

New in version 3.1.0.

weightedPrecisionReturns weighted averaged precision.

New in version 3.1.0.

weightedRecallReturns weighted averaged recall. (equals to precision, recall and f-measure)

New in version 3.1.0.

weightedTruePositiveRateReturns weighted true positive rate. (equals to precision, recall and f-measure)

New in version 3.1.0.

LinearSVCTrainingSummary

class pyspark.ml.classification.LinearSVCTrainingSummary(java_obj=None)
Abstraction for LinearSVC Training results.

New in version 3.1.0.
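
The training summary is exposed on the fitted model rather than instantiated directly; a short sketch, assuming `model` was just produced by LinearSVC(...).fit(train_df):

    if model.hasSummary:
        ts = model.summary                   # a LinearSVCTrainingSummary
        print(ts.totalIterations)            # iterations until termination
        print(len(ts.objectiveHistory))      # totalIterations + 1 (includes the initial state)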


Methods

fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.

Attributes

accuracy  Returns accuracy.
areaUnderROC  Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold  Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
labelCol  Field in “predictions” which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
objectiveHistory  Objective function (scaled loss + regularization) at each iteration.
pr  Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel  Returns precision for each label (category).
precisionByThreshold  Returns a dataframe with two fields (threshold, precision) curve.
predictionCol  Field in “predictions” which gives the prediction of each class.
predictions  Dataframe outputted by the model’s transform method.
recallByLabel  Returns recall for each label (category).
recallByThreshold  Returns a dataframe with two fields (threshold, recall) curve.
roc  Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol  Field in “predictions” which gives the probability or raw prediction of each class as a vector.
totalIterations  Number of training iterations until termination.
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in “predictions” which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.


Methods Documentation

fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.

New in version 3.1.0.

Attributes Documentation

accuracyReturns accuracy. (equals to the total number of correctly classified instances out of the total number ofinstances.)

New in version 3.1.0.

areaUnderROCComputes the area under the receiver operating characteristic (ROC) curve.

New in version 3.1.0.

fMeasureByThresholdReturns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.

New in version 3.1.0.

falsePositiveRateByLabelReturns false positive rate for each label (category).

New in version 3.1.0.

labelColField in “predictions” which gives the true label of each instance.

New in version 3.1.0.

labelsReturns the sequence of labels in ascending order. This order matches the order used in metrics which arespecified as arrays over labels, e.g., truePositiveRateByLabel.

New in version 3.1.0.

Notes

In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.

objectiveHistoryObjective function (scaled loss + regularization) at each iteration. It contains one more element, the initialstate, than number of iterations.

New in version 3.1.0.

prReturns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0,1.0) prepended to it.


New in version 3.1.0.

precisionByLabelReturns precision for each label (category).

New in version 3.1.0.

precisionByThresholdReturns a dataframe with two fields (threshold, precision) curve. Every possible probability obtained intransforming the dataset are used as thresholds used in calculating the precision.

New in version 3.1.0.

predictionColField in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictionsDataframe outputted by the model’s transform method.

New in version 3.1.0.

recallByLabelReturns recall for each label (category).

New in version 3.1.0.

recallByThresholdReturns a dataframe with two fields (threshold, recall) curve. Every possible probability obtained in trans-forming the dataset are used as thresholds used in calculating the recall.

New in version 3.1.0.

rocReturns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR,TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.

New in version 3.1.0.

Notes

Wikipedia reference

scoreColField in “predictions” which gives the probability or raw prediction of each class as a vector.

New in version 3.1.0.

totalIterationsNumber of training iterations until termination.

New in version 3.1.0.

truePositiveRateByLabelReturns true positive rate for each label (category).

New in version 3.1.0.

weightColField in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.


weightedFalsePositiveRateReturns weighted false positive rate.

New in version 3.1.0.

weightedPrecisionReturns weighted averaged precision.

New in version 3.1.0.

weightedRecallReturns weighted averaged recall. (equals to precision, recall and f-measure)

New in version 3.1.0.

weightedTruePositiveRateReturns weighted true positive rate. (equals to precision, recall and f-measure)

New in version 3.1.0.

LogisticRegression

class pyspark.ml.classification.LogisticRegression(*args, **kwargs)
Logistic regression. This class supports multinomial logistic (softmax) and binomial logistic regression.

New in version 1.3.0.

Examples

>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
>>> bdf = sc.parallelize([
...     Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)),
...     Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)),
...     Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)),
...     Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF()
>>> blor = LogisticRegression(weightCol="weight")
>>> blor.getRegParam()
0.0
>>> blor.setRegParam(0.01)
LogisticRegression...
>>> blor.getRegParam()
0.01
>>> blor.setMaxIter(10)
LogisticRegression...
>>> blor.getMaxIter()
10
>>> blor.clear(blor.maxIter)
>>> blorModel = blor.fit(bdf)
>>> blorModel.setFeaturesCol("features")
LogisticRegressionModel...
>>> blorModel.setProbabilityCol("newProbability")
LogisticRegressionModel...
>>> blorModel.getProbabilityCol()
'newProbability'
>>> blorModel.getMaxBlockSizeInMB()
0.0
>>> blorModel.setThreshold(0.1)
LogisticRegressionModel...
>>> blorModel.getThreshold()
0.1
>>> blorModel.coefficients
DenseVector([-1.080..., -0.646...])
>>> blorModel.intercept
3.112...
>>> blorModel.evaluate(bdf).accuracy == blorModel.summary.accuracy
True
>>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
>>> mdf = spark.read.format("libsvm").load(data_path)
>>> mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family="multinomial")
>>> mlorModel = mlor.fit(mdf)
>>> mlorModel.coefficientMatrix
SparseMatrix(3, 4, [0, 1, 2, 3], [3, 2, 1], [1.87..., -2.75..., -0.50...], 1)
>>> mlorModel.interceptVector
DenseVector([0.04..., -0.42..., 0.37...])
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF()
>>> blorModel.predict(test0.head().features)
1.0
>>> blorModel.predictRaw(test0.head().features)
DenseVector([-3.54..., 3.54...])
>>> blorModel.predictProbability(test0.head().features)
DenseVector([0.028, 0.972])
>>> result = blorModel.transform(test0).head()
>>> result.prediction
1.0
>>> result.newProbability
DenseVector([0.02..., 0.97...])
>>> result.rawPrediction
DenseVector([-3.54..., 3.54...])
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
>>> blorModel.transform(test1).head().prediction
1.0
>>> blor.setParams("vector")
Traceback (most recent call last):
    ...
TypeError: Method setParams forces keyword arguments.
>>> lr_path = temp_path + "/lr"
>>> blor.save(lr_path)
>>> lr2 = LogisticRegression.load(lr_path)
>>> lr2.getRegParam()
0.01
>>> model_path = temp_path + "/lr_model"
>>> blorModel.save(model_path)
>>> model2 = LogisticRegressionModel.load(model_path)
>>> blorModel.coefficients[0] == model2.coefficients[0]
True
>>> blorModel.intercept == model2.intercept
True
>>> model2
LogisticRegressionModel: uid=..., numClasses=2, numFeatures=2
>>> blorModel.transform(test0).take(1) == model2.transform(test0).take(1)
True


Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getAggregationDepth()  Gets the value of aggregationDepth or its default value.
getElasticNetParam()  Gets the value of elasticNetParam or its default value.
getFamily()  Gets the value of family or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFitIntercept()  Gets the value of fitIntercept or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getLowerBoundsOnCoefficients()  Gets the value of lowerBoundsOnCoefficients.
getLowerBoundsOnIntercepts()  Gets the value of lowerBoundsOnIntercepts.
getMaxBlockSizeInMB()  Gets the value of maxBlockSizeInMB or its default value.
getMaxIter()  Gets the value of maxIter or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getRegParam()  Gets the value of regParam or its default value.
getStandardization()  Gets the value of standardization or its default value.
getThreshold()  Get threshold for binary classification.
getThresholds()  If thresholds is set, return its value.
getTol()  Gets the value of tol or its default value.
getUpperBoundsOnCoefficients()  Gets the value of upperBoundsOnCoefficients.
getUpperBoundsOnIntercepts()  Gets the value of upperBoundsOnIntercepts.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)  Sets a parameter in the embedded param map.
setAggregationDepth(value)  Sets the value of aggregationDepth.
setElasticNetParam(value)  Sets the value of elasticNetParam.
setFamily(value)  Sets the value of family.
setFeaturesCol(value)  Sets the value of featuresCol.
setFitIntercept(value)  Sets the value of fitIntercept.
setLabelCol(value)  Sets the value of labelCol.
setLowerBoundsOnCoefficients(value)  Sets the value of lowerBoundsOnCoefficients.
setLowerBoundsOnIntercepts(value)  Sets the value of lowerBoundsOnIntercepts.
setMaxBlockSizeInMB(value)  Sets the value of maxBlockSizeInMB.
setMaxIter(value)  Sets the value of maxIter.
setParams(*args, **kwargs)  setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction", standardization=True, weightCol=None, aggregationDepth=2, family="auto", lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, maxBlockSizeInMB=0.0): Sets params for logistic regression.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setRegParam(value)  Sets the value of regParam.
setStandardization(value)  Sets the value of standardization.
setThreshold(value)  Sets the value of threshold.
setThresholds(value)  Sets the value of thresholds.
setTol(value)  Sets the value of tol.
setUpperBoundsOnCoefficients(value)  Sets the value of upperBoundsOnCoefficients.
setUpperBoundsOnIntercepts(value)  Sets the value of upperBoundsOnIntercepts.
setWeightCol(value)  Sets the value of weightCol.
write()  Returns an MLWriter instance for this ML instance.


Attributes

aggregationDepth
elasticNetParam
family
featuresCol
fitIntercept
labelCol
lowerBoundsOnCoefficients
lowerBoundsOnIntercepts
maxBlockSizeInMB
maxIter
params  Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
regParam
standardization
threshold
thresholds
tol
upperBoundsOnCoefficients
upperBoundsOnIntercepts
weightCol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first callsParams.copy and then make a copy of the companion Java pipeline component with extra params. So boththe Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters


extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)
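
A compact sketch of both call forms, with made-up parameter values; a single param map yields one fitted model, while a list of param maps yields a list of models in the same order:

    from pyspark.ml.classification import LogisticRegression

    lr = LogisticRegression(maxIter=10)
    # Assumes `train_df` is a DataFrame with "label" and "features" columns.
    model = lr.fit(train_df)                                             # one fitted model
    models = lr.fit(train_df, [{lr.regParam: 0.0}, {lr.regParam: 0.1}])  # list of models
    print(len(models))                                                   # 2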

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
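
A sketch of consuming that iterator; since the (index, model) pairs are not guaranteed to arrive in order, they are collected by index here (`train_df` is an assumed training DataFrame):

    from pyspark.ml.classification import LogisticRegression

    lr = LogisticRegression(maxIter=10)
    param_maps = [{lr.regParam: 0.0}, {lr.regParam: 0.01}, {lr.regParam: 0.1}]

    models = [None] * len(param_maps)
    for index, model in lr.fitMultiple(train_df, param_maps):
        models[index] = model    # model was fit with param_maps[index]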

getAggregationDepth()Gets the value of aggregationDepth or its default value.

getElasticNetParam()Gets the value of elasticNetParam or its default value.

getFamily()Gets the value of family or its default value.

New in version 2.1.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getFitIntercept()Gets the value of fitIntercept or its default value.

getLabelCol()Gets the value of labelCol or its default value.

getLowerBoundsOnCoefficients()Gets the value of lowerBoundsOnCoefficients

New in version 2.3.0.


getLowerBoundsOnIntercepts()Gets the value of lowerBoundsOnIntercepts

New in version 2.3.0.

getMaxBlockSizeInMB()Gets the value of maxBlockSizeInMB or its default value.

getMaxIter()Gets the value of maxIter or its default value.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getProbabilityCol()Gets the value of probabilityCol or its default value.

getRawPredictionCol()Gets the value of rawPredictionCol or its default value.

getRegParam()Gets the value of regParam or its default value.

getStandardization()Gets the value of standardization or its default value.

getThreshold()
Get threshold for binary classification.

If thresholds is set with length 2 (i.e., binary classification), this returns the equivalent threshold: 1 / (1 + thresholds(0) / thresholds(1)). Otherwise, returns threshold if set or its default value if unset.

New in version 1.4.0.

getThresholds()If thresholds is set, return its value. Otherwise, if threshold is set, return the equivalent thresholdsfor binary classification: (1-threshold, threshold). If neither are set, throw an error.

New in version 1.5.0.
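
A small numeric illustration of the two conversions described above:

    from pyspark.ml.classification import LogisticRegression

    lr = LogisticRegression(thresholds=[0.25, 0.75])
    print(lr.getThreshold())     # 1 / (1 + 0.25 / 0.75), i.e. approximately 0.75

    lr2 = LogisticRegression(threshold=0.6)
    print(lr2.getThresholds())   # (1 - threshold, threshold), i.e. [0.4, 0.6]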

getTol()Gets the value of tol or its default value.

getUpperBoundsOnCoefficients()Gets the value of upperBoundsOnCoefficients

New in version 2.3.0.

getUpperBoundsOnIntercepts()Gets the value of upperBoundsOnIntercepts

New in version 2.3.0.

getWeightCol()Gets the value of weightCol or its default value.


hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setAggregationDepth(value)Sets the value of aggregationDepth.

setElasticNetParam(value)Sets the value of elasticNetParam.

setFamily(value)Sets the value of family .

New in version 2.1.0.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setFitIntercept(value)Sets the value of fitIntercept.

setLabelCol(value)Sets the value of labelCol.

New in version 3.0.0.

setLowerBoundsOnCoefficients(value)Sets the value of lowerBoundsOnCoefficients

New in version 2.3.0.

setLowerBoundsOnIntercepts(value)Sets the value of lowerBoundsOnIntercepts

New in version 2.3.0.

setMaxBlockSizeInMB(value)Sets the value of maxBlockSizeInMB.

New in version 3.1.0.

setMaxIter(value)Sets the value of maxIter.


setParams(*args, **kwargs)
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction", standardization=True, weightCol=None, aggregationDepth=2, family="auto", lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, maxBlockSizeInMB=0.0): Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent.

New in version 1.3.0.
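
Because positional arguments are rejected, parameters are always passed by keyword; a brief sketch with illustrative values:

    from pyspark.ml.classification import LogisticRegression

    lr = LogisticRegression()
    lr.setParams(maxIter=20, regParam=0.01, elasticNetParam=0.5)   # keyword-only
    # lr.setParams(20) would raise: TypeError: Method setParams forces keyword arguments.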

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)Sets the value of rawPredictionCol.

New in version 3.0.0.

setRegParam(value)Sets the value of regParam.

setStandardization(value)Sets the value of standardization.

setThreshold(value)Sets the value of threshold. Clears value of thresholds if it has been set.

New in version 1.4.0.

setThresholds(value)Sets the value of thresholds.

New in version 3.0.0.

setTol(value)Sets the value of tol.

setUpperBoundsOnCoefficients(value)Sets the value of upperBoundsOnCoefficients

New in version 2.3.0.

setUpperBoundsOnIntercepts(value)Sets the value of upperBoundsOnIntercepts

New in version 2.3.0.

setWeightCol(value)Sets the value of weightCol.

write()Returns an MLWriter instance for this ML instance.


Attributes Documentation

aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')

elasticNetParam = Param(parent='undefined', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.')

family = Param(parent='undefined', name='family', doc='The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

lowerBoundsOnCoefficients = Param(parent='undefined', name='lowerBoundsOnCoefficients', doc='The lower bounds on coefficients if fitting under bound constrained optimization. The bound matrix must be compatible with the shape (1, number of features) for binomial regression, or (number of classes, number of features) for multinomial regression.')

lowerBoundsOnIntercepts = Param(parent='undefined', name='lowerBoundsOnIntercepts', doc='The lower bounds on intercepts if fitting under bound constrained optimization. The bounds vector size must be equal with 1 for binomial regression, or the number of classes for multinomial regression.')

maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')

standardization = Param(parent='undefined', name='standardization', doc='whether to standardize the training features before fitting the model.')

threshold = Param(parent='undefined', name='threshold', doc='Threshold in binary classification prediction, in range [0, 1]. If threshold and thresholds are both set, they must match.e.g. if threshold is p, then thresholds must be equal to [1-p, p].')

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

upperBoundsOnCoefficients = Param(parent='undefined', name='upperBoundsOnCoefficients', doc='The upper bounds on coefficients if fitting under bound constrained optimization. The bound matrix must be compatible with the shape (1, number of features) for binomial regression, or (number of classes, number of features) for multinomial regression.')

upperBoundsOnIntercepts = Param(parent='undefined', name='upperBoundsOnIntercepts', doc='The upper bounds on intercepts if fitting under bound constrained optimization. The bound vector size must be equal with 1 for binomial regression, or the number of classes for multinomial regression.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

LogisticRegressionModel

class pyspark.ml.classification.LogisticRegressionModel(java_model=None)
Model fitted by LogisticRegression.

New in version 1.3.0.


Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset)  Evaluates the model on a test dataset.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getAggregationDepth()  Gets the value of aggregationDepth or its default value.
getElasticNetParam()  Gets the value of elasticNetParam or its default value.
getFamily()  Gets the value of family or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFitIntercept()  Gets the value of fitIntercept or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getLowerBoundsOnCoefficients()  Gets the value of lowerBoundsOnCoefficients.
getLowerBoundsOnIntercepts()  Gets the value of lowerBoundsOnIntercepts.
getMaxBlockSizeInMB()  Gets the value of maxBlockSizeInMB or its default value.
getMaxIter()  Gets the value of maxIter or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getRegParam()  Gets the value of regParam or its default value.
getStandardization()  Gets the value of standardization or its default value.
getThreshold()  Get threshold for binary classification.
getThresholds()  If thresholds is set, return its value.
getTol()  Gets the value of tol or its default value.
getUpperBoundsOnCoefficients()  Gets the value of upperBoundsOnCoefficients.
getUpperBoundsOnIntercepts()  Gets the value of upperBoundsOnIntercepts.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)  Predict label for the given features.
predictProbability(value)  Predict the probability of each class given the features.
predictRaw(value)  Raw prediction for each possible label.
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setThreshold(value)  Sets the value of threshold.
setThresholds(value)  Sets the value of thresholds.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.

Attributes

aggregationDepth
coefficientMatrix  Model coefficients.
coefficients  Model coefficients of binomial logistic regression.
elasticNetParam
family
featuresCol
fitIntercept
hasSummary  Indicates whether a training summary exists for this model instance.
intercept  Model intercept of binomial logistic regression.
interceptVector  Model intercept.
labelCol
lowerBoundsOnCoefficients
lowerBoundsOnIntercepts
maxBlockSizeInMB
maxIter
numClasses  Number of classes (values which the label can take).
numFeatures  Returns the number of features the model was trained on.
params  Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
regParam
standardization
summary  Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on the training set.
threshold
thresholds
tol
upperBoundsOnCoefficients
upperBoundsOnIntercepts
weightCol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first callsParams.copy and then make a copy of the companion Java pipeline component with extra params. So boththe Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluate(dataset)Evaluates the model on a test dataset.

New in version 2.0.0.

Parameters

dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on.
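
A minimal sketch of evaluating a fitted model on held-out data; `blorModel` and `test_df` are assumed to exist, and for a binomial model the returned summary also exposes binary metrics such as areaUnderROC:

    summary = blorModel.evaluate(test_df)
    print(summary.accuracy)
    print(summary.weightedPrecision, summary.weightedRecall)
    print(summary.areaUnderROC)    # available because the model is binomial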

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getAggregationDepth()Gets the value of aggregationDepth or its default value.

getElasticNetParam()Gets the value of elasticNetParam or its default value.


getFamily()Gets the value of family or its default value.

New in version 2.1.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getFitIntercept()Gets the value of fitIntercept or its default value.

getLabelCol()Gets the value of labelCol or its default value.

getLowerBoundsOnCoefficients()Gets the value of lowerBoundsOnCoefficients

New in version 2.3.0.

getLowerBoundsOnIntercepts()Gets the value of lowerBoundsOnIntercepts

New in version 2.3.0.

getMaxBlockSizeInMB()Gets the value of maxBlockSizeInMB or its default value.

getMaxIter()Gets the value of maxIter or its default value.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getProbabilityCol()Gets the value of probabilityCol or its default value.

getRawPredictionCol()Gets the value of rawPredictionCol or its default value.

getRegParam()Gets the value of regParam or its default value.

getStandardization()Gets the value of standardization or its default value.

getThreshold()
Get threshold for binary classification.

If thresholds is set with length 2 (i.e., binary classification), this returns the equivalent threshold: 1 / (1 + thresholds(0) / thresholds(1)). Otherwise, returns threshold if set or its default value if unset.

New in version 1.4.0.

getThresholds()If thresholds is set, return its value. Otherwise, if threshold is set, return the equivalent thresholdsfor binary classification: (1-threshold, threshold). If neither are set, throw an error.


New in version 1.5.0.

getTol()Gets the value of tol or its default value.

getUpperBoundsOnCoefficients()Gets the value of upperBoundsOnCoefficients

New in version 2.3.0.

getUpperBoundsOnIntercepts()Gets the value of upperBoundsOnIntercepts

New in version 2.3.0.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)Predict label for the given features.

New in version 3.0.0.

predictProbability(value)Predict the probability of each class given the features.

New in version 3.0.0.

predictRaw(value)Raw prediction for each possible label.

New in version 3.0.0.
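
These three methods score a single feature vector on the driver, without building a DataFrame; a short sketch with a hypothetical fitted binary model `blorModel` trained on two features:

    from pyspark.ml.linalg import Vectors

    x = Vectors.dense(-1.0, 1.0)
    print(blorModel.predict(x))              # predicted label, e.g. 1.0
    print(blorModel.predictRaw(x))           # raw margin for each class (DenseVector)
    print(blorModel.predictProbability(x))   # class probabilities (DenseVector)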

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.


setProbabilityCol(value)Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)Sets the value of rawPredictionCol.

New in version 3.0.0.

setThreshold(value)Sets the value of threshold. Clears value of thresholds if it has been set.

New in version 1.4.0.

setThresholds(value)Sets the value of thresholds.

New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')

coefficientMatrixModel coefficients.

New in version 2.1.0.

coefficientsModel coefficients of binomial logistic regression. An exception is thrown in the case of multinomiallogistic regression.

New in version 2.0.0.

elasticNetParam = Param(parent='undefined', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.')

family = Param(parent='undefined', name='family', doc='The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')

hasSummaryIndicates whether a training summary exists for this model instance.

New in version 2.1.0.


interceptModel intercept of binomial logistic regression. An exception is thrown in the case of multinomial logisticregression.

New in version 1.4.0.

interceptVectorModel intercept.

New in version 2.1.0.

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

lowerBoundsOnCoefficients = Param(parent='undefined', name='lowerBoundsOnCoefficients', doc='The lower bounds on coefficients if fitting under bound constrained optimization. The bound matrix must be compatible with the shape (1, number of features) for binomial regression, or (number of classes, number of features) for multinomial regression.')

lowerBoundsOnIntercepts = Param(parent='undefined', name='lowerBoundsOnIntercepts', doc='The lower bounds on intercepts if fitting under bound constrained optimization. The bounds vector size must be equal with 1 for binomial regression, or the number of classes for multinomial regression.')

maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

numClassesNumber of classes (values which the label can take).

New in version 2.1.0.

numFeaturesReturns the number of features the model was trained on. If unknown, returns -1

New in version 2.1.0.

paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')

standardization = Param(parent='undefined', name='standardization', doc='whether to standardize the training features before fitting the model.')

summaryGets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on thetraining set. An exception is thrown if trainingSummary is None.

New in version 2.0.0.

threshold = Param(parent='undefined', name='threshold', doc='Threshold in binary classification prediction, in range [0, 1]. If threshold and thresholds are both set, they must match.e.g. if threshold is p, then thresholds must be equal to [1-p, p].')

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

upperBoundsOnCoefficients = Param(parent='undefined', name='upperBoundsOnCoefficients', doc='The upper bounds on coefficients if fitting under bound constrained optimization. The bound matrix must be compatible with the shape (1, number of features) for binomial regression, or (number of classes, number of features) for multinomial regression.')

upperBoundsOnIntercepts = Param(parent='undefined', name='upperBoundsOnIntercepts', doc='The upper bounds on intercepts if fitting under bound constrained optimization. The bound vector size must be equal with 1 for binomial regression, or the number of classes for multinomial regression.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


LogisticRegressionSummary

class pyspark.ml.classification.LogisticRegressionSummary(java_obj=None)
Abstraction for Logistic Regression Results for a given model.

New in version 2.0.0.

Methods

fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.

Attributes

accuracy  Returns accuracy.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
featuresCol  Field in “predictions” which gives the features of each instance as a vector.
labelCol  Field in “predictions” which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
precisionByLabel  Returns precision for each label (category).
predictionCol  Field in “predictions” which gives the prediction of each class.
predictions  Dataframe outputted by the model’s transform method.
probabilityCol  Field in “predictions” which gives the probability of each class as a vector.
recallByLabel  Returns recall for each label (category).
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in “predictions” which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.
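
In practice a LogisticRegressionSummary is obtained from a fitted LogisticRegressionModel, either as model.summary (metrics over the training data) or via model.evaluate(df) (metrics over any labeled DataFrame). A brief sketch of reading per-label metrics, assuming a freshly fitted `lrModel`:

    s = lrModel.summary
    for label, p, r in zip(s.labels, s.precisionByLabel, s.recallByLabel):
        print(label, p, r)
    print(s.fMeasureByLabel(beta=1.0))   # one F-measure per label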

Methods Documentation

fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.

New in version 3.1.0.


Attributes Documentation

accuracyReturns accuracy. (equals to the total number of correctly classified instances out of the total number ofinstances.)

New in version 3.1.0.

falsePositiveRateByLabelReturns false positive rate for each label (category).

New in version 3.1.0.

featuresColField in “predictions” which gives the features of each instance as a vector.

New in version 2.0.0.

labelColField in “predictions” which gives the true label of each instance.

New in version 3.1.0.

labelsReturns the sequence of labels in ascending order. This order matches the order used in metrics which arespecified as arrays over labels, e.g., truePositiveRateByLabel.

New in version 3.1.0.

Notes

In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.

precisionByLabelReturns precision for each label (category).

New in version 3.1.0.

predictionColField in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictionsDataframe outputted by the model’s transform method.

New in version 3.1.0.

probabilityColField in “predictions” which gives the probability of each class as a vector.

New in version 2.0.0.

recallByLabelReturns recall for each label (category).

New in version 3.1.0.

truePositiveRateByLabelReturns true positive rate for each label (category).

New in version 3.1.0.


weightColField in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRateReturns weighted false positive rate.

New in version 3.1.0.

weightedPrecisionReturns weighted averaged precision.

New in version 3.1.0.

weightedRecallReturns weighted averaged recall. (equals to precision, recall and f-measure)

New in version 3.1.0.

weightedTruePositiveRateReturns weighted true positive rate. (equals to precision, recall and f-measure)

New in version 3.1.0.

LogisticRegressionTrainingSummary

class pyspark.ml.classification.LogisticRegressionTrainingSummary(java_obj=None)
Abstraction for multinomial Logistic Regression Training results.

New in version 2.0.0.

Methods

fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.

Attributes

accuracy  Returns accuracy.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
featuresCol  Field in “predictions” which gives the features of each instance as a vector.
labelCol  Field in “predictions” which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
objectiveHistory  Objective function (scaled loss + regularization) at each iteration.
precisionByLabel  Returns precision for each label (category).
predictionCol  Field in “predictions” which gives the prediction of each class.
predictions  Dataframe outputted by the model’s transform method.
probabilityCol  Field in “predictions” which gives the probability of each class as a vector.
recallByLabel  Returns recall for each label (category).
totalIterations  Number of training iterations until termination.
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in “predictions” which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.

Methods Documentation

fMeasureByLabel(beta=1.0)
Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)
Returns weighted averaged f-measure.

New in version 3.1.0.

Attributes Documentation

accuracy
Returns accuracy (the fraction of correctly classified instances out of the total number of instances).

New in version 3.1.0.

falsePositiveRateByLabel
Returns false positive rate for each label (category).

New in version 3.1.0.

featuresCol
Field in “predictions” which gives the features of each instance as a vector.

New in version 2.0.0.

labelCol
Field in “predictions” which gives the true label of each instance.

New in version 3.1.0.

labels
Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

New in version 3.1.0.


Notes

In most cases, it will be the values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.

objectiveHistory
Objective function (scaled loss + regularization) at each iteration. It contains one more element, the initial state, than the number of iterations.

New in version 3.1.0.

precisionByLabel
Returns precision for each label (category).

New in version 3.1.0.

predictionCol
Field in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictions
Dataframe outputted by the model’s transform method.

New in version 3.1.0.

probabilityCol
Field in “predictions” which gives the probability of each class as a vector.

New in version 2.0.0.

recallByLabel
Returns recall for each label (category).

New in version 3.1.0.

totalIterations
Number of training iterations until termination.

New in version 3.1.0.

truePositiveRateByLabel
Returns true positive rate for each label (category).

New in version 3.1.0.

weightCol
Field in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRate
Returns weighted false positive rate.

New in version 3.1.0.

weightedPrecision
Returns weighted averaged precision.

New in version 3.1.0.

weightedRecall
Returns weighted averaged recall (equals to precision, recall and f-measure).

New in version 3.1.0.


weightedTruePositiveRate
Returns weighted true positive rate (equals to precision, recall and f-measure).

New in version 3.1.0.

BinaryLogisticRegressionSummary

class pyspark.ml.classification.BinaryLogisticRegressionSummary(java_obj=None)
Binary Logistic regression results for a given model.

New in version 2.0.0.
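As a hedged illustration, such a summary is typically produced by calling evaluate() on a fitted binary LogisticRegressionModel against a held-out DataFrame. The toy data and variable names below are assumptions made for the sketch.

from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.linalg import Vectors

spark = SparkSession.builder.getOrCreate()

train = spark.createDataFrame([
    (0.0, Vectors.dense(0.0)),
    (0.0, Vectors.dense(1.0)),
    (1.0, Vectors.dense(2.0)),
    (1.0, Vectors.dense(3.0))], ["label", "features"])
test = spark.createDataFrame([
    (0.0, Vectors.dense(0.5)),
    (1.0, Vectors.dense(2.5))], ["label", "features"])

model = LogisticRegression(maxIter=10).fit(train)

# For a binary model, evaluate() returns a BinaryLogisticRegressionSummary.
summary = model.evaluate(test)
print(summary.areaUnderROC)   # scalar AUC on the evaluation set
summary.roc.show()            # DataFrame of (FPR, TPR) points
summary.pr.show()             # DataFrame of (recall, precision) points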

Methods

fMeasureByLabel([beta]) - Returns f-measure for each label (category).
weightedFMeasure([beta]) - Returns weighted averaged f-measure.

Attributes

accuracy - Returns accuracy.
areaUnderROC - Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold - Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel - Returns false positive rate for each label (category).
featuresCol - Field in “predictions” which gives the features of each instance as a vector.
labelCol - Field in “predictions” which gives the true label of each instance.
labels - Returns the sequence of labels in ascending order.
pr - Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel - Returns precision for each label (category).
precisionByThreshold - Returns a dataframe with two fields (threshold, precision) curve.
predictionCol - Field in “predictions” which gives the prediction of each class.
predictions - Dataframe outputted by the model’s transform method.
probabilityCol - Field in “predictions” which gives the probability of each class as a vector.
recallByLabel - Returns recall for each label (category).
recallByThreshold - Returns a dataframe with two fields (threshold, recall) curve.
roc - Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol - Field in “predictions” which gives the probability or raw prediction of each class as a vector.
truePositiveRateByLabel - Returns true positive rate for each label (category).
weightCol - Field in “predictions” which gives the weight of each instance as a vector.
weightedFalsePositiveRate - Returns weighted false positive rate.
weightedPrecision - Returns weighted averaged precision.
weightedRecall - Returns weighted averaged recall.
weightedTruePositiveRate - Returns weighted true positive rate.

Methods Documentation

fMeasureByLabel(beta=1.0)
Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)
Returns weighted averaged f-measure.

New in version 3.1.0.

Attributes Documentation

accuracy
Returns accuracy (the fraction of correctly classified instances out of the total number of instances).

New in version 3.1.0.

areaUnderROC
Computes the area under the receiver operating characteristic (ROC) curve.

New in version 3.1.0.

fMeasureByThreshold
Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.

New in version 3.1.0.

falsePositiveRateByLabel
Returns false positive rate for each label (category).

New in version 3.1.0.

featuresCol
Field in “predictions” which gives the features of each instance as a vector.

New in version 2.0.0.

labelCol
Field in “predictions” which gives the true label of each instance.

New in version 3.1.0.

labels
Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.


New in version 3.1.0.

Notes

In most cases, it will be the values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.

pr
Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.

New in version 3.1.0.

precisionByLabel
Returns precision for each label (category).

New in version 3.1.0.

precisionByThreshold
Returns a dataframe with two fields (threshold, precision) curve. Every possible probability obtained in transforming the dataset is used as a threshold when calculating the precision.

New in version 3.1.0.

predictionCol
Field in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictions
Dataframe outputted by the model’s transform method.

New in version 3.1.0.

probabilityCol
Field in “predictions” which gives the probability of each class as a vector.

New in version 2.0.0.

recallByLabel
Returns recall for each label (category).

New in version 3.1.0.

recallByThreshold
Returns a dataframe with two fields (threshold, recall) curve. Every possible probability obtained in transforming the dataset is used as a threshold when calculating the recall.

New in version 3.1.0.

roc
Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.

New in version 3.1.0.


Notes

Wikipedia reference

scoreCol
Field in “predictions” which gives the probability or raw prediction of each class as a vector.

New in version 3.1.0.

truePositiveRateByLabel
Returns true positive rate for each label (category).

New in version 3.1.0.

weightCol
Field in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRate
Returns weighted false positive rate.

New in version 3.1.0.

weightedPrecision
Returns weighted averaged precision.

New in version 3.1.0.

weightedRecall
Returns weighted averaged recall (equals to precision, recall and f-measure).

New in version 3.1.0.

weightedTruePositiveRate
Returns weighted true positive rate (equals to precision, recall and f-measure).

New in version 3.1.0.

BinaryLogisticRegressionTrainingSummary

class pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary(java_obj=None)
Binary Logistic regression training results for a given model.

New in version 2.0.0.

Methods

fMeasureByLabel([beta]) - Returns f-measure for each label (category).
weightedFMeasure([beta]) - Returns weighted averaged f-measure.


Attributes

accuracy - Returns accuracy.
areaUnderROC - Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold - Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel - Returns false positive rate for each label (category).
featuresCol - Field in “predictions” which gives the features of each instance as a vector.
labelCol - Field in “predictions” which gives the true label of each instance.
labels - Returns the sequence of labels in ascending order.
objectiveHistory - Objective function (scaled loss + regularization) at each iteration.
pr - Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel - Returns precision for each label (category).
precisionByThreshold - Returns a dataframe with two fields (threshold, precision) curve.
predictionCol - Field in “predictions” which gives the prediction of each class.
predictions - Dataframe outputted by the model’s transform method.
probabilityCol - Field in “predictions” which gives the probability of each class as a vector.
recallByLabel - Returns recall for each label (category).
recallByThreshold - Returns a dataframe with two fields (threshold, recall) curve.
roc - Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol - Field in “predictions” which gives the probability or raw prediction of each class as a vector.
totalIterations - Number of training iterations until termination.
truePositiveRateByLabel - Returns true positive rate for each label (category).
weightCol - Field in “predictions” which gives the weight of each instance as a vector.
weightedFalsePositiveRate - Returns weighted false positive rate.
weightedPrecision - Returns weighted averaged precision.
weightedRecall - Returns weighted averaged recall.
weightedTruePositiveRate - Returns weighted true positive rate.


Methods Documentation

fMeasureByLabel(beta=1.0)
Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)
Returns weighted averaged f-measure.

New in version 3.1.0.

Attributes Documentation

accuracy
Returns accuracy (the fraction of correctly classified instances out of the total number of instances).

New in version 3.1.0.

areaUnderROC
Computes the area under the receiver operating characteristic (ROC) curve.

New in version 3.1.0.

fMeasureByThreshold
Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.

New in version 3.1.0.
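A common use of fMeasureByThreshold is to pick the decision threshold that maximizes F-measure on the training data and push it back onto the estimator. The sketch below is a hedged illustration of that pattern; the toy data and variable names are assumptions.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.linalg import Vectors

spark = SparkSession.builder.getOrCreate()

train = spark.createDataFrame([
    (0.0, Vectors.dense(0.0)),
    (0.0, Vectors.dense(1.0)),
    (1.0, Vectors.dense(2.0)),
    (1.0, Vectors.dense(3.0))], ["label", "features"])

lr = LogisticRegression(maxIter=10)
model = lr.fit(train)
summary = model.summary   # a BinaryLogisticRegressionTrainingSummary here

# Choose the threshold with the highest F-measure and set it on the estimator.
f_by_threshold = summary.fMeasureByThreshold
best = f_by_threshold.orderBy(F.col("F-Measure").desc()).first()
lr.setThreshold(best["threshold"])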

falsePositiveRateByLabel
Returns false positive rate for each label (category).

New in version 3.1.0.

featuresCol
Field in “predictions” which gives the features of each instance as a vector.

New in version 2.0.0.

labelCol
Field in “predictions” which gives the true label of each instance.

New in version 3.1.0.

labels
Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

New in version 3.1.0.


Notes

In most cases, it will be the values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.

objectiveHistory
Objective function (scaled loss + regularization) at each iteration. It contains one more element, the initial state, than the number of iterations.

New in version 3.1.0.

pr
Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.

New in version 3.1.0.

precisionByLabel
Returns precision for each label (category).

New in version 3.1.0.

precisionByThreshold
Returns a dataframe with two fields (threshold, precision) curve. Every possible probability obtained in transforming the dataset is used as a threshold when calculating the precision.

New in version 3.1.0.

predictionCol
Field in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictions
Dataframe outputted by the model’s transform method.

New in version 3.1.0.

probabilityCol
Field in “predictions” which gives the probability of each class as a vector.

New in version 2.0.0.

recallByLabel
Returns recall for each label (category).

New in version 3.1.0.

recallByThreshold
Returns a dataframe with two fields (threshold, recall) curve. Every possible probability obtained in transforming the dataset is used as a threshold when calculating the recall.

New in version 3.1.0.

roc
Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.

New in version 3.1.0.


Notes

Wikipedia reference

scoreCol
Field in “predictions” which gives the probability or raw prediction of each class as a vector.

New in version 3.1.0.

totalIterations
Number of training iterations until termination.

New in version 3.1.0.

truePositiveRateByLabel
Returns true positive rate for each label (category).

New in version 3.1.0.

weightCol
Field in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRate
Returns weighted false positive rate.

New in version 3.1.0.

weightedPrecision
Returns weighted averaged precision.

New in version 3.1.0.

weightedRecall
Returns weighted averaged recall (equals to precision, recall and f-measure).

New in version 3.1.0.

weightedTruePositiveRate
Returns weighted true positive rate (equals to precision, recall and f-measure).

New in version 3.1.0.

DecisionTreeClassifier

class pyspark.ml.classification.DecisionTreeClassifier(*args, **kwargs)
Decision tree learning algorithm for classification. It supports both binary and multiclass labels, as well as both continuous and categorical features.

New in version 1.4.0.


Examples

>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed", leafCol="leafId")
>>> model = dt.fit(td)
>>> model.getLabelCol()
'indexed'
>>> model.setFeaturesCol("features")
DecisionTreeClassificationModel...
>>> model.numNodes
3
>>> model.depth
1
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> model.numClasses
2
>>> print(model.toDebugString)
DecisionTreeClassificationModel...depth=1, numNodes=3...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictRaw(test0.head().features)
DenseVector([1.0, 0.0])
>>> model.predictProbability(test0.head().features)
DenseVector([1.0, 0.0])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.probability
DenseVector([1.0, 0.0])
>>> result.rawPrediction
DenseVector([1.0, 0.0])
>>> result.leafId
0.0
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> dtc_path = temp_path + "/dtc"
>>> dt.save(dtc_path)
>>> dt2 = DecisionTreeClassifier.load(dtc_path)
>>> dt2.getMaxDepth()
2
>>> model_path = temp_path + "/dtc_model"
>>> model.save(model_path)
>>> model2 = DecisionTreeClassificationModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> df3 = spark.createDataFrame([
...     (1.0, 0.2, Vectors.dense(1.0)),
...     (1.0, 0.8, Vectors.dense(1.0)),
...     (0.0, 1.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
>>> si3 = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model3 = si3.fit(df3)
>>> td3 = si_model3.transform(df3)
>>> dt3 = DecisionTreeClassifier(maxDepth=2, weightCol="weight", labelCol="indexed")
>>> model3 = dt3.fit(td3)
>>> print(model3.toDebugString)
DecisionTreeClassificationModel...depth=1, numNodes=3...

Methods

clear(param) - Clears a param from the param map if it has been explicitly set.
copy([extra]) - Creates a copy of this instance with the same uid and some extra params.
explainParam(param) - Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() - Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) - Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) - Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) - Fits a model to the input dataset for each param map in paramMaps.
getCacheNodeIds() - Gets the value of cacheNodeIds or its default value.
getCheckpointInterval() - Gets the value of checkpointInterval or its default value.
getFeaturesCol() - Gets the value of featuresCol or its default value.
getImpurity() - Gets the value of impurity or its default value.
getLabelCol() - Gets the value of labelCol or its default value.
getLeafCol() - Gets the value of leafCol or its default value.
getMaxBins() - Gets the value of maxBins or its default value.
getMaxDepth() - Gets the value of maxDepth or its default value.
getMaxMemoryInMB() - Gets the value of maxMemoryInMB or its default value.
getMinInfoGain() - Gets the value of minInfoGain or its default value.
getMinInstancesPerNode() - Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode() - Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param) - Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) - Gets a param by its name.
getPredictionCol() - Gets the value of predictionCol or its default value.
getProbabilityCol() - Gets the value of probabilityCol or its default value.
getRawPredictionCol() - Gets the value of rawPredictionCol or its default value.
getSeed() - Gets the value of seed or its default value.
getThresholds() - Gets the value of thresholds or its default value.
getWeightCol() - Gets the value of weightCol or its default value.
hasDefault(param) - Checks whether a param has a default value.
hasParam(paramName) - Tests whether this instance contains a param with a given (string) name.
isDefined(param) - Checks whether a param is explicitly set by user or has a default value.
isSet(param) - Checks whether a param is explicitly set by user.
load(path) - Reads an ML instance from the input path, a shortcut of read().load(path).
read() - Returns an MLReader instance for this class.
save(path) - Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) - Sets a parameter in the embedded param map.
setCacheNodeIds(value) - Sets the value of cacheNodeIds.
setCheckpointInterval(value) - Sets the value of checkpointInterval.
setFeaturesCol(value) - Sets the value of featuresCol.
setImpurity(value) - Sets the value of impurity.
setLabelCol(value) - Sets the value of labelCol.
setLeafCol(value) - Sets the value of leafCol.
setMaxBins(value) - Sets the value of maxBins.
setMaxDepth(value) - Sets the value of maxDepth.
setMaxMemoryInMB(value) - Sets the value of maxMemoryInMB.
setMinInfoGain(value) - Sets the value of minInfoGain.
setMinInstancesPerNode(value) - Sets the value of minInstancesPerNode.
setMinWeightFractionPerNode(value) - Sets the value of minWeightFractionPerNode.
setParams(self, *[, featuresCol, labelCol, ...]) - Sets params for the DecisionTreeClassifier.
setPredictionCol(value) - Sets the value of predictionCol.
setProbabilityCol(value) - Sets the value of probabilityCol.
setRawPredictionCol(value) - Sets the value of rawPredictionCol.
setSeed(value) - Sets the value of seed.
setThresholds(value) - Sets the value of thresholds.
setWeightCol(value) - Sets the value of weightCol.
write() - Returns an MLWriter instance for this ML instance.


Attributes

cacheNodeIds
checkpointInterval
featuresCol
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
params - Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
supportedImpurities
thresholds
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values


Returns

dict merged param map
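The merging order described above can be checked directly. The following is a small hypothetical sketch (variable names are illustrative) showing that an explicitly passed extra value overrides a user-supplied constructor value, which in turn overrides the default.

from pyspark.sql import SparkSession
from pyspark.ml.classification import DecisionTreeClassifier

spark = SparkSession.builder.getOrCreate()   # a SparkContext must be running

dt = DecisionTreeClassifier(maxDepth=3)      # user-supplied value
pm = dt.extractParamMap({dt.maxDepth: 7})    # extra value wins over user-supplied
print(pm[dt.maxDepth])                       # 7
print(pm[dt.maxBins])                        # 32, the default value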

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
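Because the iterator may yield models out of order, callers typically index the results by the returned position. The snippet below is a minimal hedged sketch of that pattern; the toy DataFrame and variable names are assumptions.

from pyspark.sql import SparkSession
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.linalg import Vectors

spark = SparkSession.builder.getOrCreate()

train = spark.createDataFrame([
    (0.0, Vectors.dense(0.0)),
    (1.0, Vectors.dense(1.0))], ["label", "features"])

dt = DecisionTreeClassifier()
param_maps = [{dt.maxDepth: d} for d in (1, 2, 3)]

# fitMultiple returns a thread-safe iterator of (index, model) pairs,
# not necessarily in the order of param_maps.
models = [None] * len(param_maps)
for index, model in dt.fitMultiple(train, param_maps):
    models[index] = model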

getCacheNodeIds()
Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getImpurity()
Gets the value of impurity or its default value.

New in version 1.6.0.

getLabelCol()
Gets the value of labelCol or its default value.

getLeafCol()
Gets the value of leafCol or its default value.

getMaxBins()
Gets the value of maxBins or its default value.

getMaxDepth()
Gets the value of maxDepth or its default value.

getMaxMemoryInMB()
Gets the value of maxMemoryInMB or its default value.


getMinInfoGain()
Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getRawPredictionCol()
Gets the value of rawPredictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getThresholds()
Gets the value of thresholds or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setCacheNodeIds(value)
Sets the value of cacheNodeIds.

setCheckpointInterval(value)
Sets the value of checkpointInterval.


New in version 1.4.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setImpurity(value)
Sets the value of impurity.

New in version 1.4.0.

setLabelCol(value)
Sets the value of labelCol.

New in version 3.0.0.

setLeafCol(value)
Sets the value of leafCol.

setMaxBins(value)
Sets the value of maxBins.

setMaxDepth(value)
Sets the value of maxDepth.

setMaxMemoryInMB(value)
Sets the value of maxMemoryInMB.

setMinInfoGain(value)
Sets the value of minInfoGain.

setMinInstancesPerNode(value)
Sets the value of minInstancesPerNode.

setMinWeightFractionPerNode(value)
Sets the value of minWeightFractionPerNode.

New in version 3.0.0.

setParams(self, \*, featuresCol="features", labelCol="label", predictionCol="prediction",probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5,maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256,cacheNodeIds=False, checkpointInterval=10, impurity="gini", seed=None, weight-Col=None, leafCol="", minWeightFractionPerNode=0.0)

Sets params for the DecisionTreeClassifier.

New in version 1.4.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)
Sets the value of rawPredictionCol.

New in version 3.0.0.

setSeed(value)
Sets the value of seed.


setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.
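For context, the thresholds param (see the Attributes Documentation below) selects the class with the largest value of p/t, where p is the predicted probability of a class and t is that class's threshold. The plain-Python arithmetic below is a hypothetical worked example of that rule, not PySpark API code; the numbers are made up for illustration.

# Worked example of the thresholds rule: predict the class with the largest p/t.
probabilities = [0.6, 0.4]   # p: class-conditional probabilities from the model
thresholds = [0.9, 0.2]      # t: values that would be passed to setThresholds()
scores = [p / t for p, t in zip(probabilities, thresholds)]   # [0.67, 2.0]
predicted = max(range(len(scores)), key=scores.__getitem__)   # class 1 wins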

setWeightCol(value)
Sets the value of weightCol.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: entropy, gini')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')

maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')

maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')

maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')

minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')

minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')

minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

supportedImpurities = ['entropy', 'gini']

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


DecisionTreeClassificationModel

class pyspark.ml.classification.DecisionTreeClassificationModel(java_model=None)
Model fitted by DecisionTreeClassifier.

New in version 1.4.0.

Methods

clear(param) - Clears a param from the param map if it has been explicitly set.
copy([extra]) - Creates a copy of this instance with the same uid and some extra params.
explainParam(param) - Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() - Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) - Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCacheNodeIds() - Gets the value of cacheNodeIds or its default value.
getCheckpointInterval() - Gets the value of checkpointInterval or its default value.
getFeaturesCol() - Gets the value of featuresCol or its default value.
getImpurity() - Gets the value of impurity or its default value.
getLabelCol() - Gets the value of labelCol or its default value.
getLeafCol() - Gets the value of leafCol or its default value.
getMaxBins() - Gets the value of maxBins or its default value.
getMaxDepth() - Gets the value of maxDepth or its default value.
getMaxMemoryInMB() - Gets the value of maxMemoryInMB or its default value.
getMinInfoGain() - Gets the value of minInfoGain or its default value.
getMinInstancesPerNode() - Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode() - Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param) - Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) - Gets a param by its name.
getPredictionCol() - Gets the value of predictionCol or its default value.
getProbabilityCol() - Gets the value of probabilityCol or its default value.
getRawPredictionCol() - Gets the value of rawPredictionCol or its default value.
getSeed() - Gets the value of seed or its default value.
getThresholds() - Gets the value of thresholds or its default value.
getWeightCol() - Gets the value of weightCol or its default value.
hasDefault(param) - Checks whether a param has a default value.
hasParam(paramName) - Tests whether this instance contains a param with a given (string) name.
isDefined(param) - Checks whether a param is explicitly set by user or has a default value.
isSet(param) - Checks whether a param is explicitly set by user.
load(path) - Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) - Predict label for the given features.
predictLeaf(value) - Predict the indices of the leaves corresponding to the feature vector.
predictProbability(value) - Predict the probability of each class given the features.
predictRaw(value) - Raw prediction for each possible label.
read() - Returns an MLReader instance for this class.
save(path) - Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) - Sets a parameter in the embedded param map.
setFeaturesCol(value) - Sets the value of featuresCol.
setLeafCol(value) - Sets the value of leafCol.
setPredictionCol(value) - Sets the value of predictionCol.
setProbabilityCol(value) - Sets the value of probabilityCol.
setRawPredictionCol(value) - Sets the value of rawPredictionCol.
setThresholds(value) - Sets the value of thresholds.
transform(dataset[, params]) - Transforms the input dataset with optional parameters.
write() - Returns an MLWriter instance for this ML instance.

Attributes

cacheNodeIds
checkpointInterval
depth - Return depth of the decision tree.
featureImportances - Estimate of the importance of each feature.
featuresCol
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numClasses - Number of classes (values which the label can take).
numFeatures - Returns the number of features the model was trained on.
numNodes - Return number of nodes of the decision tree.
params - Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
supportedImpurities
thresholds
toDebugString - Full description of model.
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getCacheNodeIds()
Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getImpurity()
Gets the value of impurity or its default value.

New in version 1.6.0.


getLabelCol()
Gets the value of labelCol or its default value.

getLeafCol()
Gets the value of leafCol or its default value.

getMaxBins()
Gets the value of maxBins or its default value.

getMaxDepth()
Gets the value of maxDepth or its default value.

getMaxMemoryInMB()
Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getRawPredictionCol()
Gets the value of rawPredictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getThresholds()
Gets the value of thresholds or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).


predict(value)
Predict label for the given features.

New in version 3.0.0.

predictLeaf(value)
Predict the indices of the leaves corresponding to the feature vector.

New in version 3.0.0.

predictProbability(value)
Predict the probability of each class given the features.

New in version 3.0.0.

predictRaw(value)
Raw prediction for each possible label.

New in version 3.0.0.

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setLeafCol(value)
Sets the value of leafCol.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)
Sets the value of rawPredictionCol.

New in version 3.0.0.

setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.


Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

depth
Return depth of the decision tree.

New in version 1.5.0.

featureImportances
Estimate of the importance of each feature.

This generalizes the idea of “Gini” importance to other losses, following the explanation of Gini importance from “Random Forests” documentation by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.

This feature importance is calculated as follows:

• importance(feature j) = sum (over nodes which split on feature j) of the gain, where gain is scaled by the number of instances passing through the node

• Normalize importances for tree to sum to 1.

New in version 2.0.0.

Notes

Feature importance for single decision trees can have high variance due to correlated predictor variables. Consider using a RandomForestClassifier to determine feature importance instead.
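The sketch below is a small, hedged illustration (the toy data and variable names are assumptions) of reading featureImportances from a fitted model and ranking feature indices by importance.

from pyspark.sql import SparkSession
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.linalg import Vectors

spark = SparkSession.builder.getOrCreate()

# Two features; only the first one separates the classes.
train = spark.createDataFrame([
    (0.0, Vectors.dense(0.0, 5.0)),
    (0.0, Vectors.dense(1.0, 5.0)),
    (1.0, Vectors.dense(2.0, 5.0)),
    (1.0, Vectors.dense(3.0, 5.0))], ["label", "features"])

model = DecisionTreeClassifier(maxDepth=2).fit(train)

importances = model.featureImportances   # a SparseVector over feature indices
ranked = sorted(enumerate(importances.toArray()), key=lambda kv: -kv[1])
print(ranked)                             # e.g. [(0, 1.0), (1, 0.0)]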

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: entropy, gini')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')

maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')

maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')

maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')

minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')

minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')

minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')

numClasses
Number of classes (values which the label can take).

New in version 2.1.0.


numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.

New in version 2.1.0.

numNodes
Return number of nodes of the decision tree.

New in version 1.5.0.

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

supportedImpurities = ['entropy', 'gini']

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")

toDebugString
Full description of model.

New in version 2.0.0.

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

GBTClassifier

class pyspark.ml.classification.GBTClassifier(*args, **kwargs)
Gradient-Boosted Trees (GBTs) learning algorithm for classification. It supports binary labels, as well as both continuous and categorical features.

New in version 1.4.0.

Notes

Multiclass labels are not currently supported.

The implementation is based upon: J.H. Friedman. “Stochastic Gradient Boosting.” 1999.

Gradient Boosting vs. TreeBoost:

• This implementation is for Stochastic Gradient Boosting, not for TreeBoost.

• Both algorithms learn tree ensembles by minimizing loss functions.

• TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes based on the loss function, whereas the original gradient boosting method does not.

• We expect to implement TreeBoost in the future: SPARK-4240


Examples

>>> from numpy import allclose
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42,
...     leafCol="leafId")
>>> gbt.setMaxIter(5)
GBTClassifier...
>>> gbt.setMinWeightFractionPerNode(0.049)
GBTClassifier...
>>> gbt.getMaxIter()
5
>>> gbt.getFeatureSubsetStrategy()
'all'
>>> model = gbt.fit(td)
>>> model.getLabelCol()
'indexed'
>>> model.setFeaturesCol("features")
GBTClassificationModel...
>>> model.setThresholds([0.3, 0.7])
GBTClassificationModel...
>>> model.getThresholds()
[0.3, 0.7]
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictRaw(test0.head().features)
DenseVector([1.1697, -1.1697])
>>> model.predictProbability(test0.head().features)
DenseVector([0.9121, 0.0879])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.leafId
DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> model.totalNumNodes
15
>>> print(model.toDebugString)
GBTClassificationModel...numTrees=5...
>>> gbtc_path = temp_path + "gbtc"
>>> gbt.save(gbtc_path)
>>> gbt2 = GBTClassifier.load(gbtc_path)
>>> gbt2.getMaxDepth()
2
>>> model_path = temp_path + "gbtc_model"
>>> model.save(model_path)
>>> model2 = GBTClassificationModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.treeWeights == model2.treeWeights
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> model.trees
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
...     ["indexed", "features"])
>>> model.evaluateEachIteration(validation)
[0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
>>> model.numClasses
2
>>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
>>> gbt.getValidationIndicatorCol()
'validationIndicator'
>>> gbt.getValidationTol()
0.01

Methods

clear(param) - Clears a param from the param map if it has been explicitly set.
copy([extra]) - Creates a copy of this instance with the same uid and some extra params.
explainParam(param) - Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() - Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) - Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) - Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) - Fits a model to the input dataset for each param map in paramMaps.
getCacheNodeIds() - Gets the value of cacheNodeIds or its default value.
getCheckpointInterval() - Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy() - Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol() - Gets the value of featuresCol or its default value.
getImpurity() - Gets the value of impurity or its default value.
getLabelCol() - Gets the value of labelCol or its default value.
getLeafCol() - Gets the value of leafCol or its default value.
getLossType() - Gets the value of lossType or its default value.
getMaxBins() - Gets the value of maxBins or its default value.
getMaxDepth() - Gets the value of maxDepth or its default value.
getMaxIter() - Gets the value of maxIter or its default value.
getMaxMemoryInMB() - Gets the value of maxMemoryInMB or its default value.
getMinInfoGain() - Gets the value of minInfoGain or its default value.
getMinInstancesPerNode() - Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode() - Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param) - Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) - Gets a param by its name.
getPredictionCol() - Gets the value of predictionCol or its default value.
getProbabilityCol() - Gets the value of probabilityCol or its default value.
getRawPredictionCol() - Gets the value of rawPredictionCol or its default value.
getSeed() - Gets the value of seed or its default value.
getStepSize() - Gets the value of stepSize or its default value.
getSubsamplingRate() - Gets the value of subsamplingRate or its default value.
getThresholds() - Gets the value of thresholds or its default value.
getValidationIndicatorCol() - Gets the value of validationIndicatorCol or its default value.
getValidationTol() - Gets the value of validationTol or its default value.
getWeightCol() - Gets the value of weightCol or its default value.
hasDefault(param) - Checks whether a param has a default value.
hasParam(paramName) - Tests whether this instance contains a param with a given (string) name.
isDefined(param) - Checks whether a param is explicitly set by user or has a default value.
isSet(param) - Checks whether a param is explicitly set by user.
load(path) - Reads an ML instance from the input path, a shortcut of read().load(path).
read() - Returns an MLReader instance for this class.
save(path) - Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) - Sets a parameter in the embedded param map.
setCacheNodeIds(value) - Sets the value of cacheNodeIds.
setCheckpointInterval(value) - Sets the value of checkpointInterval.
setFeatureSubsetStrategy(value) - Sets the value of featureSubsetStrategy.
setFeaturesCol(value) - Sets the value of featuresCol.
setImpurity(value) - Sets the value of impurity.
setLabelCol(value) - Sets the value of labelCol.
setLeafCol(value) - Sets the value of leafCol.
setLossType(value) - Sets the value of lossType.
setMaxBins(value) - Sets the value of maxBins.
setMaxDepth(value) - Sets the value of maxDepth.
setMaxIter(value) - Sets the value of maxIter.
setMaxMemoryInMB(value) - Sets the value of maxMemoryInMB.
setMinInfoGain(value) - Sets the value of minInfoGain.
setMinInstancesPerNode(value) - Sets the value of minInstancesPerNode.
setMinWeightFractionPerNode(value) - Sets the value of minWeightFractionPerNode.
setParams(self, *[, featuresCol, labelCol, ...]) - Sets params for Gradient Boosted Tree Classification.
setPredictionCol(value) - Sets the value of predictionCol.
setProbabilityCol(value) - Sets the value of probabilityCol.
setRawPredictionCol(value) - Sets the value of rawPredictionCol.
setSeed(value) - Sets the value of seed.
setStepSize(value) - Sets the value of stepSize.
setSubsamplingRate(value) - Sets the value of subsamplingRate.
setThresholds(value) - Sets the value of thresholds.
setValidationIndicatorCol(value) - Sets the value of validationIndicatorCol.
setWeightCol(value) - Sets the value of weightCol.
write() - Returns an MLWriter instance for this ML instance.

Attributes

cacheNodeIds
checkpointInterval
featureSubsetStrategy
featuresCol
impurity
labelCol
leafCol
lossType
maxBins
maxDepth
maxIter
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
params - Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
stepSize
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
supportedLossTypes
thresholds
validationIndicatorCol
validationTol
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
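As an editor-added illustration (not part of the original API text), the following minimal sketch shows the merge ordering described above. It assumes an active SparkSession, as the doctest examples elsewhere in this reference do.

from pyspark.ml.classification import GBTClassifier

gbt = GBTClassifier(maxDepth=3)                # user-supplied value (the default is 5)
pm = gbt.extractParamMap({gbt.maxDepth: 7})    # extra value wins over default and user-supplied
print(pm[gbt.maxDepth])                        # 7
print(gbt.extractParamMap()[gbt.maxDepth])     # 3: user-supplied value wins over the default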

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)
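A minimal editor-added sketch (made-up toy data, assuming an active SparkSession named spark) of fitting with a list of param maps, which returns one model per map:

from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import GBTClassifier

df = spark.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.dense(0.0))], ["label", "features"])
gbt = GBTClassifier(maxIter=5, seed=42)
models = gbt.fit(df, params=[{gbt.maxDepth: 2}, {gbt.maxDepth: 3}])
print(len(models))   # 2: one fitted model per param map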

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
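A minimal editor-added sketch (toy data, assumed SparkSession) of consuming the iterator; the index is what ties each model back to its param map, since models may arrive out of order:

from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import GBTClassifier

df = spark.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.dense(0.0))], ["label", "features"])
gbt = GBTClassifier(maxIter=5, seed=42)
param_maps = [{gbt.maxDepth: 2}, {gbt.maxDepth: 3}]
models = [None] * len(param_maps)
for index, model in gbt.fitMultiple(df, param_maps):
    models[index] = model   # index identifies the param map this model was fit with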

getCacheNodeIds()
Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getFeatureSubsetStrategy()
Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getImpurity()
Gets the value of impurity or its default value.

New in version 1.4.0.

getLabelCol()
Gets the value of labelCol or its default value.

getLeafCol()
Gets the value of leafCol or its default value.

getLossType()
Gets the value of lossType or its default value.

New in version 1.4.0.

getMaxBins()
Gets the value of maxBins or its default value.

getMaxDepth()
Gets the value of maxDepth or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getMaxMemoryInMB()
Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
Gets the value of minWeightFractionPerNode or its default value.


getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getRawPredictionCol()
Gets the value of rawPredictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getStepSize()
Gets the value of stepSize or its default value.

getSubsamplingRate()
Gets the value of subsamplingRate or its default value.

New in version 1.4.0.

getThresholds()
Gets the value of thresholds or its default value.

getValidationIndicatorCol()
Gets the value of validationIndicatorCol or its default value.

getValidationTol()
Gets the value of validationTol or its default value.

New in version 3.0.0.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.
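To make the distinction between these checks concrete, here is a small editor-added sketch (assuming an active SparkSession):

from pyspark.ml.classification import GBTClassifier

gbt = GBTClassifier(maxDepth=3)
print(gbt.isSet(gbt.maxDepth))        # True: explicitly set by the user
print(gbt.isSet(gbt.maxIter))         # False: only a default exists
print(gbt.hasDefault(gbt.maxIter))    # True
print(gbt.isDefined(gbt.maxIter))     # True: set by user or has a default
print(gbt.getOrDefault(gbt.maxIter))  # 20, the default value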

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.


setCacheNodeIds(value)
Sets the value of cacheNodeIds.

setCheckpointInterval(value)
Sets the value of checkpointInterval.

New in version 1.4.0.

setFeatureSubsetStrategy(value)
Sets the value of featureSubsetStrategy.

New in version 2.4.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setImpurity(value)
Sets the value of impurity.

New in version 1.4.0.

setLabelCol(value)
Sets the value of labelCol.

New in version 3.0.0.

setLeafCol(value)
Sets the value of leafCol.

setLossType(value)
Sets the value of lossType.

New in version 1.4.0.

setMaxBins(value)
Sets the value of maxBins.

setMaxDepth(value)
Sets the value of maxDepth.

setMaxIter(value)
Sets the value of maxIter.

New in version 1.4.0.

setMaxMemoryInMB(value)
Sets the value of maxMemoryInMB.

setMinInfoGain(value)
Sets the value of minInfoGain.

setMinInstancesPerNode(value)
Sets the value of minInstancesPerNode.

setMinWeightFractionPerNode(value)
Sets the value of minWeightFractionPerNode.

New in version 3.0.0.


setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance", featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, weightCol=None)

Sets params for Gradient Boosted Tree Classification.

New in version 1.4.0.
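A minimal end-to-end sketch added by the editor (toy data, assumed SparkSession; not the original doctest for this class) of configuring the estimator through setParams and fitting it:

from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import GBTClassifier

df = spark.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.dense(0.0))], ["label", "features"])
gbt = GBTClassifier()
gbt.setParams(maxIter=10, maxDepth=2, lossType="logistic", seed=42)
model = gbt.fit(df)
print(model.getMaxIter())   # 10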

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)
Sets the value of rawPredictionCol.

New in version 3.0.0.

setSeed(value)
Sets the value of seed.

New in version 1.4.0.

setStepSize(value)
Sets the value of stepSize.

New in version 1.4.0.

setSubsamplingRate(value)
Sets the value of subsamplingRate.

New in version 1.4.0.

setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.

setValidationIndicatorCol(value)
Sets the value of validationIndicatorCol.

New in version 3.0.0.

setWeightCol(value)
Sets the value of weightCol.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.


Attributes Documentation

cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')

lossType = Param(parent='undefined', name='lossType', doc='Loss function which GBT tries to minimize (case-insensitive). Supported options: logistic')

maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')

maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')

minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')

minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')

minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

stepSize = Param(parent='undefined', name='stepSize', doc='Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.')

subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')

supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']

supportedImpurities = ['variance']

supportedLossTypes = ['logistic']

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")

validationIndicatorCol = Param(parent='undefined', name='validationIndicatorCol', doc='name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.')

validationTol = Param(parent='undefined', name='validationTol', doc='Threshold for stopping early when fit with validation is used. If the error rate on the validation input changes by less than the validationTol, then learning will stop early (before `maxIter`). This parameter is ignored when fit without validation is used.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
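The validationIndicatorCol and validationTol params documented above drive early stopping. A minimal editor-added sketch (toy data, assumed SparkSession) of enabling it:

from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import GBTClassifier

df = spark.createDataFrame([
    (1.0, Vectors.dense(1.0), False),   # training row
    (0.0, Vectors.dense(0.0), False),   # training row
    (1.0, Vectors.dense(0.9), True),    # validation row
    (0.0, Vectors.dense(0.1), True),    # validation row
], ["label", "features", "isVal"])
gbt = GBTClassifier(maxIter=50, validationIndicatorCol="isVal", validationTol=0.01)
model = gbt.fit(df)   # may stop before maxIter if the validation error stops improving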


GBTClassificationModel

class pyspark.ml.classification.GBTClassificationModel(java_model=None)
Model fitted by GBTClassifier.

New in version 1.4.0.

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
evaluateEachIteration(dataset)  Method to compute error or loss for every iteration of gradient boosting.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCacheNodeIds()  Gets the value of cacheNodeIds or its default value.
getCheckpointInterval()  Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy()  Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getImpurity()  Gets the value of impurity or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getLeafCol()  Gets the value of leafCol or its default value.
getLossType()  Gets the value of lossType or its default value.
getMaxBins()  Gets the value of maxBins or its default value.
getMaxDepth()  Gets the value of maxDepth or its default value.
getMaxIter()  Gets the value of maxIter or its default value.
getMaxMemoryInMB()  Gets the value of maxMemoryInMB or its default value.
getMinInfoGain()  Gets the value of minInfoGain or its default value.
getMinInstancesPerNode()  Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode()  Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getSeed()  Gets the value of seed or its default value.
getStepSize()  Gets the value of stepSize or its default value.
getSubsamplingRate()  Gets the value of subsamplingRate or its default value.
getThresholds()  Gets the value of thresholds or its default value.
getValidationIndicatorCol()  Gets the value of validationIndicatorCol or its default value.
getValidationTol()  Gets the value of validationTol or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)  Predict label for the given features.
predictLeaf(value)  Predict the indices of the leaves corresponding to the feature vector.
predictProbability(value)  Predict the probability of each class given the features.
predictRaw(value)  Raw prediction for each possible label.
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setLeafCol(value)  Sets the value of leafCol.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setThresholds(value)  Sets the value of thresholds.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.

Attributes

cacheNodeIds
checkpointInterval
featureImportances  Estimate of the importance of each feature.
featureSubsetStrategy
featuresCol
getNumTrees  Number of trees in ensemble.
impurity
labelCol
leafCol
lossType
maxBins
maxDepth
maxIter
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numClasses  Number of classes (values which the label can take).
numFeatures  Returns the number of features the model was trained on.
params  Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
stepSize
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
supportedLossTypes
thresholds
toDebugString  Full description of model.
totalNumNodes  Total number of nodes, summed over all trees in the ensemble.
treeWeights  Return the weights for each tree.
trees  Trees in this ensemble.
validationIndicatorCol
validationTol
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluateEachIteration(dataset)
Method to compute error or loss for every iteration of gradient boosting.

New in version 2.4.0.

Parameters

dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on.
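A minimal editor-added sketch (toy data, assumed SparkSession) of inspecting the per-iteration loss on a held-out dataset:

from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import GBTClassifier

train = spark.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.dense(0.0))], ["label", "features"])
test = spark.createDataFrame(
    [(1.0, Vectors.dense(0.8)), (0.0, Vectors.dense(0.2))], ["label", "features"])
model = GBTClassifier(maxIter=5, seed=42).fit(train)
losses = model.evaluateEachIteration(test)   # one loss value per boosting iteration
print(len(losses))                           # 5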

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getCacheNodeIds()
Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getFeatureSubsetStrategy()
Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getImpurity()
Gets the value of impurity or its default value.

New in version 1.4.0.

getLabelCol()
Gets the value of labelCol or its default value.

getLeafCol()
Gets the value of leafCol or its default value.

getLossType()
Gets the value of lossType or its default value.

New in version 1.4.0.

getMaxBins()
Gets the value of maxBins or its default value.

getMaxDepth()
Gets the value of maxDepth or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getMaxMemoryInMB()
Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getRawPredictionCol()
Gets the value of rawPredictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getStepSize()
Gets the value of stepSize or its default value.

getSubsamplingRate()
Gets the value of subsamplingRate or its default value.

New in version 1.4.0.

getThresholds()
Gets the value of thresholds or its default value.

getValidationIndicatorCol()
Gets the value of validationIndicatorCol or its default value.

getValidationTol()
Gets the value of validationTol or its default value.

New in version 3.0.0.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
Predict label for the given features.

New in version 3.0.0.

predictLeaf(value)
Predict the indices of the leaves corresponding to the feature vector.

New in version 3.0.0.

predictProbability(value)
Predict the probability of each class given the features.

New in version 3.0.0.

predictRaw(value)
Raw prediction for each possible label.

New in version 3.0.0.
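These four helpers take a single pyspark.ml.linalg vector rather than a DataFrame. A minimal editor-added sketch (toy data, assumed SparkSession):

from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import GBTClassifier

train = spark.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.dense(0.0))], ["label", "features"])
model = GBTClassifier(maxIter=5, seed=42).fit(train)
v = Vectors.dense(0.8)
print(model.predict(v))             # predicted label, a float
print(model.predictRaw(v))          # raw (margin) scores per class, a DenseVector
print(model.predictProbability(v))  # class probabilities, a DenseVector
print(model.predictLeaf(v))         # predicted leaf index in each tree, a DenseVector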

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setLeafCol(value)
Sets the value of leafCol.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)
Sets the value of rawPredictionCol.

New in version 3.0.0.

setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.


Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

featureImportances
Estimate of the importance of each feature.

Each feature’s importance is the average of its importance across all trees in the ensemble. The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. (Hastie, Tibshirani, Friedman. “The Elements of Statistical Learning, 2nd Edition.” 2001.) and follows the implementation from scikit-learn.

New in version 2.0.0.

See also:

DecisionTreeClassificationModel.featureImportances
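A minimal editor-added sketch (toy two-feature data, assumed SparkSession) of reading the normalized importances:

from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import GBTClassifier

train = spark.createDataFrame(
    [(1.0, Vectors.dense([1.0, 0.0])), (0.0, Vectors.dense([0.0, 0.0]))],
    ["label", "features"])
model = GBTClassifier(maxIter=5, seed=42).fit(train)
imp = model.featureImportances      # a Vector over the two input features
print(sum(imp.toArray()))           # importances are normalized to sum to 1.0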

featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

getNumTrees
Number of trees in ensemble.

New in version 2.0.0.

impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')

lossType = Param(parent='undefined', name='lossType', doc='Loss function which GBT tries to minimize (case-insensitive). Supported options: logistic')

maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')

maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')

minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')

minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')

minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')

numClasses
Number of classes (values which the label can take).

New in version 2.1.0.

numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.

New in version 2.1.0.

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

stepSize = Param(parent='undefined', name='stepSize', doc='Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.')

subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')

supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']

supportedImpurities = ['variance']

supportedLossTypes = ['logistic']

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")

toDebugString
Full description of model.

New in version 2.0.0.

totalNumNodes
Total number of nodes, summed over all trees in the ensemble.

New in version 2.0.0.

treeWeights
Return the weights for each tree.

New in version 1.5.0.

trees
Trees in this ensemble. Warning: These have null parent Estimators.

New in version 2.0.0.

validationIndicatorCol = Param(parent='undefined', name='validationIndicatorCol', doc='name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.')

validationTol = Param(parent='undefined', name='validationTol', doc='Threshold for stopping early when fit with validation is used. If the error rate on the validation input changes by less than the validationTol, then learning will stop early (before `maxIter`). This parameter is ignored when fit without validation is used.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


RandomForestClassifier

class pyspark.ml.classification.RandomForestClassifier(*args, **kwargs)
Random Forest learning algorithm for classification. It supports both binary and multiclass labels, as well as both continuous and categorical features.

New in version 1.4.0.

Examples

>>> import numpy
>>> from numpy import allclose
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42,
...     leafCol="leafId")
>>> rf.getMinWeightFractionPerNode()
0.0
>>> model = rf.fit(td)
>>> model.getLabelCol()
'indexed'
>>> model.setFeaturesCol("features")
RandomForestClassificationModel...
>>> model.setRawPredictionCol("newRawPrediction")
RandomForestClassificationModel...
>>> model.getBootstrap()
True
>>> model.getRawPredictionCol()
'newRawPrediction'
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictRaw(test0.head().features)
DenseVector([2.0, 0.0])
>>> model.predictProbability(test0.head().features)
DenseVector([1.0, 0.0])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> numpy.argmax(result.probability)
0
>>> numpy.argmax(result.newRawPrediction)
0
>>> result.leafId
DenseVector([0.0, 0.0, 0.0])
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> model.trees
[DecisionTreeClassificationModel...depth=..., DecisionTreeClassificationModel...]
>>> rfc_path = temp_path + "/rfc"
>>> rf.save(rfc_path)
>>> rf2 = RandomForestClassifier.load(rfc_path)
>>> rf2.getNumTrees()
3
>>> model_path = temp_path + "/rfc_model"
>>> model.save(model_path)
>>> model2 = RandomForestClassificationModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getBootstrap()  Gets the value of bootstrap or its default value.
getCacheNodeIds()  Gets the value of cacheNodeIds or its default value.
getCheckpointInterval()  Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy()  Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getImpurity()  Gets the value of impurity or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getLeafCol()  Gets the value of leafCol or its default value.
getMaxBins()  Gets the value of maxBins or its default value.
getMaxDepth()  Gets the value of maxDepth or its default value.
getMaxMemoryInMB()  Gets the value of maxMemoryInMB or its default value.
getMinInfoGain()  Gets the value of minInfoGain or its default value.
getMinInstancesPerNode()  Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode()  Gets the value of minWeightFractionPerNode or its default value.
getNumTrees()  Gets the value of numTrees or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getSeed()  Gets the value of seed or its default value.
getSubsamplingRate()  Gets the value of subsamplingRate or its default value.
getThresholds()  Gets the value of thresholds or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setBootstrap(value)  Sets the value of bootstrap.
setCacheNodeIds(value)  Sets the value of cacheNodeIds.
setCheckpointInterval(value)  Sets the value of checkpointInterval.
setFeatureSubsetStrategy(value)  Sets the value of featureSubsetStrategy.
setFeaturesCol(value)  Sets the value of featuresCol.
setImpurity(value)  Sets the value of impurity.
setLabelCol(value)  Sets the value of labelCol.
setLeafCol(value)  Sets the value of leafCol.
setMaxBins(value)  Sets the value of maxBins.
setMaxDepth(value)  Sets the value of maxDepth.
setMaxMemoryInMB(value)  Sets the value of maxMemoryInMB.
setMinInfoGain(value)  Sets the value of minInfoGain.
setMinInstancesPerNode(value)  Sets the value of minInstancesPerNode.
setMinWeightFractionPerNode(value)  Sets the value of minWeightFractionPerNode.
setNumTrees(value)  Sets the value of numTrees.
setParams(self[, featuresCol, labelCol, ...])  Sets params for linear classification.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setSeed(value)  Sets the value of seed.
setSubsamplingRate(value)  Sets the value of subsamplingRate.
setThresholds(value)  Sets the value of thresholds.
setWeightCol(value)  Sets the value of weightCol.
write()  Returns an MLWriter instance for this ML instance.

Attributes

bootstrap
cacheNodeIds
checkpointInterval
featureSubsetStrategy
featuresCol
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numTrees
params  Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
thresholds
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance


explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getBootstrap()
Gets the value of bootstrap or its default value.

New in version 3.0.0.

getCacheNodeIds()
Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.


getFeatureSubsetStrategy()
Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getImpurity()
Gets the value of impurity or its default value.

New in version 1.6.0.

getLabelCol()
Gets the value of labelCol or its default value.

getLeafCol()
Gets the value of leafCol or its default value.

getMaxBins()
Gets the value of maxBins or its default value.

getMaxDepth()
Gets the value of maxDepth or its default value.

getMaxMemoryInMB()
Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
Gets the value of minWeightFractionPerNode or its default value.

getNumTrees()
Gets the value of numTrees or its default value.

New in version 1.4.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getRawPredictionCol()
Gets the value of rawPredictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getSubsamplingRate()
Gets the value of subsamplingRate or its default value.

New in version 1.4.0.


getThresholds()
Gets the value of thresholds or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setBootstrap(value)
Sets the value of bootstrap.

New in version 3.0.0.

setCacheNodeIds(value)
Sets the value of cacheNodeIds.

setCheckpointInterval(value)
Sets the value of checkpointInterval.

setFeatureSubsetStrategy(value)
Sets the value of featureSubsetStrategy.

New in version 2.4.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setImpurity(value)
Sets the value of impurity.

New in version 1.4.0.

setLabelCol(value)
Sets the value of labelCol.

New in version 3.0.0.

setLeafCol(value)
Sets the value of leafCol.


setMaxBins(value)
Sets the value of maxBins.

setMaxDepth(value)
Sets the value of maxDepth.

setMaxMemoryInMB(value)
Sets the value of maxMemoryInMB.

setMinInfoGain(value)
Sets the value of minInfoGain.

setMinInstancesPerNode(value)
Sets the value of minInstancesPerNode.

setMinWeightFractionPerNode(value)
Sets the value of minWeightFractionPerNode.

New in version 3.0.0.

setNumTrees(value)
Sets the value of numTrees.

New in version 1.4.0.

setParams(self, featuresCol='features', labelCol='label', predictionCol='prediction', probabilityCol='probability', rawPredictionCol='rawPrediction', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, impurity='gini', numTrees=20, featureSubsetStrategy='auto', subsamplingRate=1.0, leafCol='', minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)

Sets params for linear classification.

New in version 1.4.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)
Sets the value of rawPredictionCol.

New in version 3.0.0.

setSeed(value)
Sets the value of seed.

setSubsamplingRate(value)
Sets the value of subsamplingRate.

New in version 1.4.0.

setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.

setWeightCol(value)
Sets the value of weightCol.


New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

bootstrap = Param(parent='undefined', name='bootstrap', doc='Whether bootstrap samples are used when building trees.')

cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: entropy, gini')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')

maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')

maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')

maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')

minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')

minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')

minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')

numTrees = Param(parent='undefined', name='numTrees', doc='Number of trees to train (>= 1).')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')

supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']

supportedImpurities = ['entropy', 'gini']

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


RandomForestClassificationModel

class pyspark.ml.classification.RandomForestClassificationModel(java_model=None)
Model fitted by RandomForestClassifier.

New in version 1.4.0.

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset)  Evaluates the model on a test dataset.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBootstrap()  Gets the value of bootstrap or its default value.
getCacheNodeIds()  Gets the value of cacheNodeIds or its default value.
getCheckpointInterval()  Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy()  Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getImpurity()  Gets the value of impurity or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getLeafCol()  Gets the value of leafCol or its default value.
getMaxBins()  Gets the value of maxBins or its default value.
getMaxDepth()  Gets the value of maxDepth or its default value.
getMaxMemoryInMB()  Gets the value of maxMemoryInMB or its default value.
getMinInfoGain()  Gets the value of minInfoGain or its default value.
getMinInstancesPerNode()  Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode()  Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getSeed()  Gets the value of seed or its default value.
getSubsamplingRate()  Gets the value of subsamplingRate or its default value.
getThresholds()  Gets the value of thresholds or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)  Predict label for the given features.
predictLeaf(value)  Predict the indices of the leaves corresponding to the feature vector.
predictProbability(value)  Predict the probability of each class given the features.
predictRaw(value)  Raw prediction for each possible label.
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setLeafCol(value)  Sets the value of leafCol.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setThresholds(value)  Sets the value of thresholds.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.

Attributes

bootstrap
cacheNodeIds
checkpointInterval
featureImportances  Estimate of the importance of each feature.
featureSubsetStrategy
featuresCol
getNumTrees  Number of trees in ensemble.
hasSummary  Indicates whether a training summary exists for this model instance.
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numClasses  Number of classes (values which the label can take).
numFeatures  Returns the number of features the model was trained on.
numTrees
params  Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
subsamplingRate
summary  Gets summary (e.g.
supportedFeatureSubsetStrategies
supportedImpurities
thresholds
toDebugString  Full description of model.
totalNumNodes  Total number of nodes, summed over all trees in the ensemble.
treeWeights  Return the weights for each tree.
trees  Trees in this ensemble.
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluate(dataset)
Evaluates the model on a test dataset.

New in version 3.1.0.

Parameters

dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on.
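For illustration, a minimal sketch of calling evaluate on a fitted model; the DataFrame names train_df and test_df and their "label"/"features" columns are assumptions, not part of this reference.

    from pyspark.ml.classification import RandomForestClassifier

    # Assumed DataFrames with "label" and "features" columns.
    rf = RandomForestClassifier(labelCol="label", featuresCol="features", numTrees=20)
    model = rf.fit(train_df)
    test_summary = model.evaluate(test_df)   # summary computed on the held-out data
    print(test_summary.accuracy, test_summary.weightedFMeasure())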

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.


explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
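As a small, hedged sketch of the merge ordering described above (default < user-supplied < extra); the estimator instance here is hypothetical.

    from pyspark.ml.classification import RandomForestClassifier

    rf = RandomForestClassifier(maxDepth=3)        # user-supplied value
    merged = rf.extractParamMap({rf.maxDepth: 7})  # extra value wins over user-supplied
    print(merged[rf.maxDepth])                     # 7
    print(rf.extractParamMap()[rf.maxDepth])       # 3, the user-supplied value wins over the default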

getBootstrap()Gets the value of bootstrap or its default value.

New in version 3.0.0.

getCacheNodeIds()Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()Gets the value of checkpointInterval or its default value.

getFeatureSubsetStrategy()Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getImpurity()Gets the value of impurity or its default value.

New in version 1.6.0.

getLabelCol()Gets the value of labelCol or its default value.

getLeafCol()Gets the value of leafCol or its default value.

getMaxBins()Gets the value of maxBins or its default value.

getMaxDepth()Gets the value of maxDepth or its default value.

getMaxMemoryInMB()Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()Gets the value of minWeightFractionPerNode or its default value.


getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getProbabilityCol()Gets the value of probabilityCol or its default value.

getRawPredictionCol()Gets the value of rawPredictionCol or its default value.

getSeed()Gets the value of seed or its default value.

getSubsamplingRate()Gets the value of subsamplingRate or its default value.

New in version 1.4.0.

getThresholds()Gets the value of thresholds or its default value.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)Predict label for the given features.

New in version 3.0.0.

predictLeaf(value)Predict the indices of the leaves corresponding to the feature vector.

New in version 3.0.0.

predictProbability(value)Predict the probability of each class given the features.

New in version 3.0.0.

predictRaw(value)Raw prediction for each possible label.

New in version 3.0.0.


classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setLeafCol(value)Sets the value of leafCol.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)Sets the value of rawPredictionCol.

New in version 3.0.0.

setThresholds(value)Sets the value of thresholds.

New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset
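A brief sketch, assuming a fitted model named model and an input DataFrame df with a "features" column, of overriding an embedded param for a single transform call; the output column name "rf_prediction" is hypothetical.

    predictions = model.transform(df)
    # Override the prediction column name for this call only.
    renamed = model.transform(df, {model.predictionCol: "rf_prediction"})
    renamed.select("features", "rf_prediction").show()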

write()Returns an MLWriter instance for this ML instance.


Attributes Documentation

bootstrap = Param(parent='undefined', name='bootstrap', doc='Whether bootstrap samples are used when building trees.')

cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

featureImportancesEstimate of the importance of each feature.

Each feature's importance is the average of its importance across all trees in the ensemble. The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) and follows the implementation from scikit-learn.

New in version 2.0.0.

See also:

DecisionTreeClassificationModel.featureImportances
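For example, a hedged sketch of ranking features by importance; the feature names are assumptions for illustration and must match the order of the training feature vector.

    importances = model.featureImportances          # a vector that sums to 1.0
    feature_names = ["f0", "f1", "f2"]              # assumed input feature names
    ranked = sorted(zip(feature_names, importances.toArray()), key=lambda kv: -kv[1])
    for name, score in ranked:
        print(name, round(float(score), 3))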

featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

getNumTreesNumber of trees in ensemble.

New in version 2.0.0.

hasSummaryIndicates whether a training summary exists for this model instance.

New in version 2.1.0.

impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: entropy, gini')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')

maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')

maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')

maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')

minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')

minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')

minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')

numClassesNumber of classes (values which the label can take).

New in version 2.1.0.

numFeaturesReturns the number of features the model was trained on. If unknown, returns -1

New in version 2.1.0.

numTrees = Param(parent='undefined', name='numTrees', doc='Number of trees to train (>= 1).')


paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')

summary
Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of the model trained on the training set. An exception is thrown if trainingSummary is None.

New in version 3.1.0.

supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']

supportedImpurities = ['entropy', 'gini']

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")
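As a plain-Python sketch of the p/t rule described in the thresholds doc above (the class with the largest p/t is predicted); the numbers are made up for illustration.

    probabilities = [0.30, 0.55, 0.15]   # hypothetical per-class probabilities p
    thresholds = [0.5, 1.0, 0.2]         # per-class thresholds t
    ratios = [p / t for p, t in zip(probabilities, thresholds)]
    predicted = max(range(len(ratios)), key=lambda i: ratios[i])
    # ratios == [0.6, 0.55, 0.75], so class 2 is predicted even though class 1 has the largest p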

toDebugStringFull description of model.

New in version 2.0.0.

totalNumNodesTotal number of nodes, summed over all trees in the ensemble.

New in version 2.0.0.

treeWeightsReturn the weights for each tree

New in version 1.5.0.

treesTrees in this ensemble. Warning: These have null parent Estimators.

New in version 2.0.0.

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

RandomForestClassificationSummary

class pyspark.ml.classification.RandomForestClassificationSummary(java_obj=None)
Abstraction for RandomForestClassification Results for a given model.

New in version 3.1.0.
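This summary is not constructed directly; a hedged sketch of the usual ways to obtain one from a fitted RandomForestClassificationModel (train_df and test_df are assumed DataFrames):

    from pyspark.ml.classification import RandomForestClassifier

    model = RandomForestClassifier(numTrees=10).fit(train_df)
    if model.hasSummary:
        train_summary = model.summary            # training summary for the fitted model
        print(train_summary.accuracy, train_summary.weightedRecall)
    test_summary = model.evaluate(test_df)       # summary computed on held-out data
    print(test_summary.falsePositiveRateByLabel)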


Methods

fMeasureByLabel([beta]) Returns f-measure for each label (category).
weightedFMeasure([beta]) Returns weighted averaged f-measure.

Attributes

accuracy Returns accuracy.
falsePositiveRateByLabel Returns false positive rate for each label (category).
labelCol Field in "predictions" which gives the true label of each instance.
labels Returns the sequence of labels in ascending order.
precisionByLabel Returns precision for each label (category).
predictionCol Field in "predictions" which gives the prediction of each class.
predictions Dataframe outputted by the model's transform method.
recallByLabel Returns recall for each label (category).
truePositiveRateByLabel Returns true positive rate for each label (category).
weightCol Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate Returns weighted false positive rate.
weightedPrecision Returns weighted averaged precision.
weightedRecall Returns weighted averaged recall.
weightedTruePositiveRate Returns weighted true positive rate.

Methods Documentation

fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.

New in version 3.1.0.

Attributes Documentation

accuracyReturns accuracy. (equals to the total number of correctly classified instances out of the total number ofinstances.)

New in version 3.1.0.

falsePositiveRateByLabelReturns false positive rate for each label (category).

New in version 3.1.0.

labelColField in “predictions” which gives the true label of each instance.


New in version 3.1.0.

labels
Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

New in version 3.1.0.

Notes

In most cases, it will be the values {0.0, 1.0, ..., numClasses - 1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses - 1 instead of the expected numClasses.

precisionByLabelReturns precision for each label (category).

New in version 3.1.0.

predictionColField in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictionsDataframe outputted by the model’s transform method.

New in version 3.1.0.

recallByLabelReturns recall for each label (category).

New in version 3.1.0.

truePositiveRateByLabelReturns true positive rate for each label (category).

New in version 3.1.0.

weightColField in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRateReturns weighted false positive rate.

New in version 3.1.0.

weightedPrecisionReturns weighted averaged precision.

New in version 3.1.0.

weightedRecallReturns weighted averaged recall. (equals to precision, recall and f-measure)

New in version 3.1.0.

weightedTruePositiveRateReturns weighted true positive rate. (equals to precision, recall and f-measure)

New in version 3.1.0.


RandomForestClassificationTrainingSummary

class pyspark.ml.classification.RandomForestClassificationTrainingSummary(java_obj=None)
Abstraction for RandomForestClassification training results.

New in version 3.1.0.

Methods

fMeasureByLabel([beta]) Returns f-measure for each label (category).
weightedFMeasure([beta]) Returns weighted averaged f-measure.

Attributes

accuracy Returns accuracy.
falsePositiveRateByLabel Returns false positive rate for each label (category).
labelCol Field in "predictions" which gives the true label of each instance.
labels Returns the sequence of labels in ascending order.
objectiveHistory Objective function (scaled loss + regularization) at each iteration.
precisionByLabel Returns precision for each label (category).
predictionCol Field in "predictions" which gives the prediction of each class.
predictions Dataframe outputted by the model's transform method.
recallByLabel Returns recall for each label (category).
totalIterations Number of training iterations until termination.
truePositiveRateByLabel Returns true positive rate for each label (category).
weightCol Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate Returns weighted false positive rate.
weightedPrecision Returns weighted averaged precision.
weightedRecall Returns weighted averaged recall.
weightedTruePositiveRate Returns weighted true positive rate.

Methods Documentation

fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.

New in version 3.1.0.


Attributes Documentation

accuracyReturns accuracy. (equals to the total number of correctly classified instances out of the total number ofinstances.)

New in version 3.1.0.

falsePositiveRateByLabelReturns false positive rate for each label (category).

New in version 3.1.0.

labelColField in “predictions” which gives the true label of each instance.

New in version 3.1.0.

labels
Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

New in version 3.1.0.

Notes

In most cases, it will be the values {0.0, 1.0, ..., numClasses - 1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses - 1 instead of the expected numClasses.

objectiveHistoryObjective function (scaled loss + regularization) at each iteration. It contains one more element, the initialstate, than number of iterations.

New in version 3.1.0.

precisionByLabelReturns precision for each label (category).

New in version 3.1.0.

predictionColField in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictionsDataframe outputted by the model’s transform method.

New in version 3.1.0.

recallByLabelReturns recall for each label (category).

New in version 3.1.0.

totalIterationsNumber of training iterations until termination.

New in version 3.1.0.


truePositiveRateByLabelReturns true positive rate for each label (category).

New in version 3.1.0.

weightColField in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRateReturns weighted false positive rate.

New in version 3.1.0.

weightedPrecisionReturns weighted averaged precision.

New in version 3.1.0.

weightedRecallReturns weighted averaged recall. (equals to precision, recall and f-measure)

New in version 3.1.0.

weightedTruePositiveRateReturns weighted true positive rate. (equals to precision, recall and f-measure)

New in version 3.1.0.

BinaryRandomForestClassificationSummary

class pyspark.ml.classification.BinaryRandomForestClassificationSummary(java_obj=None)
BinaryRandomForestClassification results for a given model.

New in version 3.1.0.
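A hedged sketch of inspecting the curve-style metrics on this summary; the summary here is assumed to come from evaluate() on a fitted model and a two-class dataset (binary_test_df is an assumption).

    summary = model.evaluate(binary_test_df)
    print(summary.areaUnderROC)
    summary.roc.show(5)                  # DataFrame of (FPR, TPR) points
    summary.pr.show(5)                   # DataFrame of (recall, precision) points
    summary.fMeasureByThreshold.show(5)  # DataFrame of (threshold, F-Measure) points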

Methods

fMeasureByLabel([beta]) Returns f-measure for each label (category).
weightedFMeasure([beta]) Returns weighted averaged f-measure.

Attributes

accuracy Returns accuracy.
areaUnderROC Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel Returns false positive rate for each label (category).
labelCol Field in "predictions" which gives the true label of each instance.
labels Returns the sequence of labels in ascending order.
pr Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel Returns precision for each label (category).
precisionByThreshold Returns a dataframe with two fields (threshold, precision) curve.
predictionCol Field in "predictions" which gives the prediction of each class.
predictions Dataframe outputted by the model's transform method.
recallByLabel Returns recall for each label (category).
recallByThreshold Returns a dataframe with two fields (threshold, recall) curve.
roc Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol Field in "predictions" which gives the probability or raw prediction of each class as a vector.
truePositiveRateByLabel Returns true positive rate for each label (category).
weightCol Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate Returns weighted false positive rate.
weightedPrecision Returns weighted averaged precision.
weightedRecall Returns weighted averaged recall.
weightedTruePositiveRate Returns weighted true positive rate.

Methods Documentation

fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.

New in version 3.1.0.

Attributes Documentation

accuracyReturns accuracy. (equals to the total number of correctly classified instances out of the total number ofinstances.)

New in version 3.1.0.

areaUnderROCComputes the area under the receiver operating characteristic (ROC) curve.

New in version 3.1.0.

fMeasureByThresholdReturns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.


New in version 3.1.0.

falsePositiveRateByLabelReturns false positive rate for each label (category).

New in version 3.1.0.

labelColField in “predictions” which gives the true label of each instance.

New in version 3.1.0.

labels
Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

New in version 3.1.0.

Notes

In most cases, it will be the values {0.0, 1.0, ..., numClasses - 1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses - 1 instead of the expected numClasses.

prReturns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0,1.0) prepended to it.

New in version 3.1.0.

precisionByLabelReturns precision for each label (category).

New in version 3.1.0.

precisionByThresholdReturns a dataframe with two fields (threshold, precision) curve. Every possible probability obtained intransforming the dataset are used as thresholds used in calculating the precision.

New in version 3.1.0.

predictionColField in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictionsDataframe outputted by the model’s transform method.

New in version 3.1.0.

recallByLabelReturns recall for each label (category).

New in version 3.1.0.

recallByThresholdReturns a dataframe with two fields (threshold, recall) curve. Every possible probability obtained in trans-forming the dataset are used as thresholds used in calculating the recall.

New in version 3.1.0.


rocReturns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR,TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.

New in version 3.1.0.

Notes

Wikipedia reference

scoreColField in “predictions” which gives the probability or raw prediction of each class as a vector.

New in version 3.1.0.

truePositiveRateByLabelReturns true positive rate for each label (category).

New in version 3.1.0.

weightColField in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRateReturns weighted false positive rate.

New in version 3.1.0.

weightedPrecisionReturns weighted averaged precision.

New in version 3.1.0.

weightedRecallReturns weighted averaged recall. (equals to precision, recall and f-measure)

New in version 3.1.0.

weightedTruePositiveRateReturns weighted true positive rate. (equals to precision, recall and f-measure)

New in version 3.1.0.

BinaryRandomForestClassificationTrainingSummary

class pyspark.ml.classification.BinaryRandomForestClassificationTrainingSummary(java_obj=None)
BinaryRandomForestClassification training results for a given model.

New in version 3.1.0.


Methods

fMeasureByLabel([beta]) Returns f-measure for each label (category).
weightedFMeasure([beta]) Returns weighted averaged f-measure.

Attributes

accuracy Returns accuracy.
areaUnderROC Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel Returns false positive rate for each label (category).
labelCol Field in "predictions" which gives the true label of each instance.
labels Returns the sequence of labels in ascending order.
objectiveHistory Objective function (scaled loss + regularization) at each iteration.
pr Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel Returns precision for each label (category).
precisionByThreshold Returns a dataframe with two fields (threshold, precision) curve.
predictionCol Field in "predictions" which gives the prediction of each class.
predictions Dataframe outputted by the model's transform method.
recallByLabel Returns recall for each label (category).
recallByThreshold Returns a dataframe with two fields (threshold, recall) curve.
roc Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol Field in "predictions" which gives the probability or raw prediction of each class as a vector.
totalIterations Number of training iterations until termination.
truePositiveRateByLabel Returns true positive rate for each label (category).
weightCol Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate Returns weighted false positive rate.
weightedPrecision Returns weighted averaged precision.
weightedRecall Returns weighted averaged recall.
weightedTruePositiveRate Returns weighted true positive rate.


Methods Documentation

fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.

New in version 3.1.0.

Attributes Documentation

accuracyReturns accuracy. (equals to the total number of correctly classified instances out of the total number ofinstances.)

New in version 3.1.0.

areaUnderROCComputes the area under the receiver operating characteristic (ROC) curve.

New in version 3.1.0.

fMeasureByThresholdReturns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.

New in version 3.1.0.

falsePositiveRateByLabelReturns false positive rate for each label (category).

New in version 3.1.0.

labelColField in “predictions” which gives the true label of each instance.

New in version 3.1.0.

labels
Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

New in version 3.1.0.

Notes

In most cases, it will be the values {0.0, 1.0, ..., numClasses - 1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses - 1 instead of the expected numClasses.

objectiveHistoryObjective function (scaled loss + regularization) at each iteration. It contains one more element, the initialstate, than number of iterations.

New in version 3.1.0.

prReturns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0,1.0) prepended to it.


New in version 3.1.0.

precisionByLabelReturns precision for each label (category).

New in version 3.1.0.

precisionByThresholdReturns a dataframe with two fields (threshold, precision) curve. Every possible probability obtained intransforming the dataset are used as thresholds used in calculating the precision.

New in version 3.1.0.

predictionColField in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictionsDataframe outputted by the model’s transform method.

New in version 3.1.0.

recallByLabelReturns recall for each label (category).

New in version 3.1.0.

recallByThresholdReturns a dataframe with two fields (threshold, recall) curve. Every possible probability obtained in trans-forming the dataset are used as thresholds used in calculating the recall.

New in version 3.1.0.

rocReturns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR,TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.

New in version 3.1.0.

Notes

Wikipedia reference

scoreColField in “predictions” which gives the probability or raw prediction of each class as a vector.

New in version 3.1.0.

totalIterationsNumber of training iterations until termination.

New in version 3.1.0.

truePositiveRateByLabelReturns true positive rate for each label (category).

New in version 3.1.0.

weightColField in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.


weightedFalsePositiveRateReturns weighted false positive rate.

New in version 3.1.0.

weightedPrecisionReturns weighted averaged precision.

New in version 3.1.0.

weightedRecallReturns weighted averaged recall. (equals to precision, recall and f-measure)

New in version 3.1.0.

weightedTruePositiveRateReturns weighted true positive rate. (equals to precision, recall and f-measure)

New in version 3.1.0.

NaiveBayes

class pyspark.ml.classification.NaiveBayes(*args, **kwargs)
Naive Bayes classifier. It supports both Multinomial and Bernoulli NB. Multinomial NB can handle finitely supported discrete data. For example, by converting documents into TF-IDF vectors, it can be used for document classification. By making every vector binary (0/1) data, it can also be used as Bernoulli NB.

The input feature values for Multinomial NB and Bernoulli NB must be nonnegative. Since 3.0.0, it supports Complement NB, which is an adaptation of Multinomial NB. Specifically, Complement NB uses statistics from the complement of each class to compute the model's coefficients. The inventors of Complement NB show empirically that the parameter estimates for CNB are more stable than those for Multinomial NB. Like Multinomial NB, the input feature values for Complement NB must be nonnegative. Since 3.0.0, it also supports Gaussian NB, which can handle continuous data.

New in version 1.5.0.

Examples

>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
...     Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
...     Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0]))])
>>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")
>>> model = nb.fit(df)
>>> model.setFeaturesCol("features")
NaiveBayesModel...
>>> model.getSmoothing()
1.0
>>> model.pi
DenseVector([-0.81..., -0.58...])
>>> model.theta
DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1)
>>> model.sigma
DenseMatrix(0, 0, [...], ...)
>>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
>>> model.predict(test0.head().features)
1.0
>>> model.predictRaw(test0.head().features)
DenseVector([-1.72..., -0.99...])
>>> model.predictProbability(test0.head().features)
DenseVector([0.32..., 0.67...])
>>> result = model.transform(test0).head()
>>> result.prediction
1.0
>>> result.probability
DenseVector([0.32..., 0.67...])
>>> result.rawPrediction
DenseVector([-1.72..., -0.99...])
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
>>> model.transform(test1).head().prediction
1.0
>>> nb_path = temp_path + "/nb"
>>> nb.save(nb_path)
>>> nb2 = NaiveBayes.load(nb_path)
>>> nb2.getSmoothing()
1.0
>>> model_path = temp_path + "/nb_model"
>>> model.save(model_path)
>>> model2 = NaiveBayesModel.load(model_path)
>>> model.pi == model2.pi
True
>>> model.theta == model2.theta
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> nb = nb.setThresholds([0.01, 10.00])
>>> model3 = nb.fit(df)
>>> result = model3.transform(test0).head()
>>> result.prediction
0.0
>>> nb3 = NaiveBayes().setModelType("gaussian")
>>> model4 = nb3.fit(df)
>>> model4.getModelType()
'gaussian'
>>> model4.sigma
DenseMatrix(2, 2, [0.0, 0.25, 0.0, 0.0], 1)
>>> nb5 = NaiveBayes(smoothing=1.0, modelType="complement", weightCol="weight")
>>> model5 = nb5.fit(df)
>>> model5.getModelType()
'complement'
>>> model5.theta
DenseMatrix(2, 2, [...], 1)
>>> model5.sigma
DenseMatrix(0, 0, [...], ...)


Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getFeaturesCol() Gets the value of featuresCol or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getModelType() Gets the value of modelType or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getProbabilityCol() Gets the value of probabilityCol or its default value.
getRawPredictionCol() Gets the value of rawPredictionCol or its default value.
getSmoothing() Gets the value of smoothing or its default value.
getThresholds() Gets the value of thresholds or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setLabelCol(value) Sets the value of labelCol.
setModelType(value) Sets the value of modelType.
setParams(self, \*[, featuresCol, labelCol, ...]) Sets params for Naive Bayes.
setPredictionCol(value) Sets the value of predictionCol.
setProbabilityCol(value) Sets the value of probabilityCol.
setRawPredictionCol(value) Sets the value of rawPredictionCol.
setSmoothing(value) Sets the value of smoothing.
setThresholds(value) Sets the value of thresholds.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.

Attributes

featuresCol
labelCol
modelType
params Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
smoothing
thresholds
weightCol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns


dict merged param map

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for eachparam map. Each call to next(modelIterator) will return (index, model) where model wasfit using paramMaps[index]. index values may not be sequential.
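As a hedged sketch of fitting one estimator under several param maps (df is an assumed training DataFrame with "label" and "features" columns):

    nb = NaiveBayes()
    param_maps = [{nb.smoothing: 0.5}, {nb.smoothing: 1.0}, {nb.smoothing: 2.0}]
    models = [None] * len(param_maps)
    for index, model in nb.fitMultiple(df, param_maps):
        models[index] = model            # indices may arrive out of order
    # Convenience form: fit() with a list of param maps returns a list of models.
    models_again = nb.fit(df, param_maps)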

getFeaturesCol()Gets the value of featuresCol or its default value.

getLabelCol()Gets the value of labelCol or its default value.

getModelType()Gets the value of modelType or its default value.

New in version 1.5.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getProbabilityCol()Gets the value of probabilityCol or its default value.

getRawPredictionCol()Gets the value of rawPredictionCol or its default value.

getSmoothing()Gets the value of smoothing or its default value.


New in version 1.5.0.

getThresholds()Gets the value of thresholds or its default value.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setLabelCol(value)Sets the value of labelCol.

New in version 3.0.0.

setModelType(value)Sets the value of modelType.

New in version 1.5.0.

setParams(self, \*, featuresCol="features", labelCol="label", predictionCol="prediction", prob-abilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, model-Type="multinomial", thresholds=None, weightCol=None)

Sets params for Naive Bayes.

New in version 1.5.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)Sets the value of probabilityCol.

New in version 3.0.0.


setRawPredictionCol(value)Sets the value of rawPredictionCol.

New in version 3.0.0.

setSmoothing(value)Sets the value of smoothing.

New in version 1.5.0.

setThresholds(value)Sets the value of thresholds.

New in version 3.0.0.

setWeightCol(value)Sets the value of weightCol.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

modelType = Param(parent='undefined', name='modelType', doc='The model type which is a string (case-sensitive). Supported options: multinomial (default), bernoulli and gaussian.')

paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

smoothing = Param(parent='undefined', name='smoothing', doc='The smoothing parameter, should be >= 0, default is 1.0')

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

NaiveBayesModel

class pyspark.ml.classification.NaiveBayesModel(java_model=None)
Model fitted by NaiveBayes.

New in version 1.5.0.


Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFeaturesCol() Gets the value of featuresCol or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getModelType() Gets the value of modelType or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getProbabilityCol() Gets the value of probabilityCol or its default value.
getRawPredictionCol() Gets the value of rawPredictionCol or its default value.
getSmoothing() Gets the value of smoothing or its default value.
getThresholds() Gets the value of thresholds or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) Predict label for the given features.
predictProbability(value) Predict the probability of each class given the features.
predictRaw(value) Raw prediction for each possible label.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setPredictionCol(value) Sets the value of predictionCol.
setProbabilityCol(value) Sets the value of probabilityCol.
setRawPredictionCol(value) Sets the value of rawPredictionCol.
setThresholds(value) Sets the value of thresholds.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

featuresCol
labelCol
modelType
numClasses Number of classes (values which the label can take).
numFeatures Returns the number of features the model was trained on.
params Returns all params ordered by name.
pi log of class priors.
predictionCol
probabilityCol
rawPredictionCol
sigma variance of each feature.
smoothing
theta log of class conditional probabilities.
thresholds
weightCol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters


extra [dict, optional] extra param values

Returns

dict merged param map

getFeaturesCol()Gets the value of featuresCol or its default value.

getLabelCol()Gets the value of labelCol or its default value.

getModelType()Gets the value of modelType or its default value.

New in version 1.5.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getProbabilityCol()Gets the value of probabilityCol or its default value.

getRawPredictionCol()Gets the value of rawPredictionCol or its default value.

getSmoothing()Gets the value of smoothing or its default value.

New in version 1.5.0.

getThresholds()Gets the value of thresholds or its default value.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)Predict label for the given features.

New in version 3.0.0.


predictProbability(value)Predict the probability of each class given the features.

New in version 3.0.0.

predictRaw(value)Raw prediction for each possible label.

New in version 3.0.0.

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)Sets the value of rawPredictionCol.

New in version 3.0.0.

setThresholds(value)Sets the value of thresholds.

New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.


Attributes Documentation

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

modelType = Param(parent='undefined', name='modelType', doc='The model type which is a string (case-sensitive). Supported options: multinomial (default), bernoulli and gaussian.')

numClassesNumber of classes (values which the label can take).

New in version 2.1.0.

numFeaturesReturns the number of features the model was trained on. If unknown, returns -1

New in version 2.1.0.

paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.

pilog of class priors.

New in version 2.0.0.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

sigmavariance of each feature.

New in version 3.0.0.

smoothing = Param(parent='undefined', name='smoothing', doc='The smoothing parameter, should be >= 0, default is 1.0')

thetalog of class conditional probabilities.

New in version 2.0.0.

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

MultilayerPerceptronClassifier

class pyspark.ml.classification.MultilayerPerceptronClassifier(*args, **kwargs)
Classifier trainer based on the Multilayer Perceptron. Each layer has a sigmoid activation function; the output layer has softmax. The number of inputs has to be equal to the size of the feature vectors, and the number of outputs has to be equal to the total number of labels.

New in version 1.6.0.


Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (0.0, Vectors.dense([0.0, 0.0])),
...     (1.0, Vectors.dense([0.0, 1.0])),
...     (1.0, Vectors.dense([1.0, 0.0])),
...     (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"])
>>> mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], seed=123)
>>> mlp.setMaxIter(100)
MultilayerPerceptronClassifier...
>>> mlp.getMaxIter()
100
>>> mlp.getBlockSize()
128
>>> mlp.setBlockSize(1)
MultilayerPerceptronClassifier...
>>> mlp.getBlockSize()
1
>>> model = mlp.fit(df)
>>> model.setFeaturesCol("features")
MultilayerPerceptronClassificationModel...
>>> model.getMaxIter()
100
>>> model.getLayers()
[2, 2, 2]
>>> model.weights.size
12
>>> testDF = spark.createDataFrame([
...     (Vectors.dense([1.0, 0.0]),),
...     (Vectors.dense([0.0, 0.0]),)], ["features"])
>>> model.predict(testDF.head().features)
1.0
>>> model.predictRaw(testDF.head().features)
DenseVector([-16.208, 16.344])
>>> model.predictProbability(testDF.head().features)
DenseVector([0.0, 1.0])
>>> model.transform(testDF).select("features", "prediction").show()
+---------+----------+
| features|prediction|
+---------+----------+
|[1.0,0.0]|       1.0|
|[0.0,0.0]|       0.0|
+---------+----------+
...
>>> mlp_path = temp_path + "/mlp"
>>> mlp.save(mlp_path)
>>> mlp2 = MultilayerPerceptronClassifier.load(mlp_path)
>>> mlp2.getBlockSize()
1
>>> model_path = temp_path + "/mlp_model"
>>> model.save(model_path)
>>> model2 = MultilayerPerceptronClassificationModel.load(model_path)
>>> model.getLayers() == model2.getLayers()
True
>>> model.weights == model2.weights
True
>>> model.transform(testDF).take(1) == model2.transform(testDF).take(1)
True
>>> mlp2 = mlp2.setInitialWeights(list(range(0, 12)))
>>> model3 = mlp2.fit(df)
>>> model3.weights != model2.weights
True
>>> model3.getLayers() == model.getLayers()
True

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getBlockSize() Gets the value of blockSize or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getInitialWeights() Gets the value of initialWeights or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getLayers() Gets the value of layers or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getProbabilityCol() Gets the value of probabilityCol or its default value.
getRawPredictionCol() Gets the value of rawPredictionCol or its default value.
getSeed() Gets the value of seed or its default value.
getSolver() Gets the value of solver or its default value.
getStepSize() Gets the value of stepSize or its default value.
getThresholds() Gets the value of thresholds or its default value.
getTol() Gets the value of tol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setBlockSize(value) Sets the value of blockSize.
setFeaturesCol(value) Sets the value of featuresCol.
setInitialWeights(value) Sets the value of initialWeights.
setLabelCol(value) Sets the value of labelCol.
setLayers(value) Sets the value of layers.
setMaxIter(value) Sets the value of maxIter.
setParams(*args, **kwargs) setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, solver="l-bfgs", initialWeights=None, probabilityCol="probability", rawPredictionCol="rawPrediction"): Sets params for MultilayerPerceptronClassifier.
setPredictionCol(value) Sets the value of predictionCol.
setProbabilityCol(value) Sets the value of probabilityCol.
setRawPredictionCol(value) Sets the value of rawPredictionCol.
setSeed(value) Sets the value of seed.
setSolver(value) Sets the value of solver.
setStepSize(value) Sets the value of stepSize.
setThresholds(value) Sets the value of thresholds.
setTol(value) Sets the value of tol.
write() Returns an MLWriter instance for this ML instance.

Attributes

blockSize
featuresCol
initialWeights
labelCol
layers
maxIter
params  Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
solver
stepSize
thresholds
tol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
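For instance, values passed through the extra argument take precedence over both defaults and explicitly set params; a small sketch reusing the mlp estimator from the example above:

# maxIter was explicitly set to 100 earlier; the extra map still wins.
merged = mlp.extractParamMap({mlp.maxIter: 5})
print(merged[mlp.maxIter])   # 5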

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.


Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
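A rough sketch of consuming fitMultiple, assuming the mlp estimator and DataFrame df from the example above; because index values are not guaranteed to arrive in order, each model is stored by its index:

param_maps = [{mlp.maxIter: 10}, {mlp.maxIter: 50}]
models = [None] * len(param_maps)
for index, fitted in mlp.fitMultiple(df, param_maps):
    models[index] = fitted   # model fitted with param_maps[index]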

getBlockSize()Gets the value of blockSize or its default value.

getFeaturesCol()Gets the value of featuresCol or its default value.

getInitialWeights()Gets the value of initialWeights or its default value.

New in version 2.0.0.

getLabelCol()Gets the value of labelCol or its default value.

getLayers()Gets the value of layers or its default value.

New in version 1.6.0.

getMaxIter()Gets the value of maxIter or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getProbabilityCol()Gets the value of probabilityCol or its default value.

getRawPredictionCol()Gets the value of rawPredictionCol or its default value.

getSeed()Gets the value of seed or its default value.

getSolver()Gets the value of solver or its default value.

getStepSize()Gets the value of stepSize or its default value.

getThresholds()Gets the value of thresholds or its default value.

getTol()Gets the value of tol or its default value.


hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setBlockSize(value)Sets the value of blockSize.

New in version 1.6.0.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setInitialWeights(value)Sets the value of initialWeights.

New in version 2.0.0.

setLabelCol(value)Sets the value of labelCol.

New in version 3.0.0.

setLayers(value)Sets the value of layers.

New in version 1.6.0.

setMaxIter(value)Sets the value of maxIter.

setParams(*args, **kwargs)
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, solver="l-bfgs", initialWeights=None, probabilityCol="probability", rawPredictionCol="rawPrediction"): Sets params for MultilayerPerceptronClassifier.

New in version 1.6.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.


setProbabilityCol(value)Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)Sets the value of rawPredictionCol.

New in version 3.0.0.

setSeed(value)Sets the value of seed.

setSolver(value)Sets the value of solver.

setStepSize(value)Sets the value of stepSize.

New in version 2.0.0.

setThresholds(value)Sets the value of thresholds.

New in version 3.0.0.

setTol(value)Sets the value of tol.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

blockSize = Param(parent='undefined', name='blockSize', doc='block size for stacking input data in matrices. Data is stacked within partitions. If block size is more than remaining data in a partition then it is adjusted to the size of this data.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

initialWeights = Param(parent='undefined', name='initialWeights', doc='The initial weights of the model.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

layers = Param(parent='undefined', name='layers', doc='Sizes of layers from input layer to output layer E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 neurons and output layer of 10 neurons.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: l-bfgs, gd.')

stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')
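To make the thresholds rule above concrete: the predicted class maximizes p/t over the class probabilities p and the per-class thresholds t. A plain-Python illustration with made-up numbers:

probabilities = [0.40, 0.35, 0.25]   # hypothetical class probabilities
thresholds = [0.8, 0.4, 0.4]         # hypothetical per-class thresholds
scores = [p / t for p, t in zip(probabilities, thresholds)]
print(scores.index(max(scores)))     # 1: class 1 wins despite class 0 having the highest p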


MultilayerPerceptronClassificationModel

class pyspark.ml.classification.MultilayerPerceptronClassificationModel(java_model=None)
Model fitted by MultilayerPerceptronClassifier.

New in version 1.6.0.

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset)  Evaluates the model on a test dataset.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBlockSize()  Gets the value of blockSize or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getInitialWeights()  Gets the value of initialWeights or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getLayers()  Gets the value of layers or its default value.
getMaxIter()  Gets the value of maxIter or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getSeed()  Gets the value of seed or its default value.
getSolver()  Gets the value of solver or its default value.
getStepSize()  Gets the value of stepSize or its default value.
getThresholds()  Gets the value of thresholds or its default value.
getTol()  Gets the value of tol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value)  Predict label for the given features.
predictProbability(value)  Predict the probability of each class given the features.
predictRaw(value)  Raw prediction for each possible label.
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setThresholds(value)  Sets the value of thresholds.
summary()  Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on the training set.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.

Attributes

blockSize
featuresCol
hasSummary  Indicates whether a training summary exists for this model instance.
initialWeights
labelCol
layers
maxIter
numClasses  Number of classes (values which the label can take).
numFeatures  Returns the number of features the model was trained on.
params  Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
seed
solver
stepSize
thresholds
tol
weights  The weights of layers.


Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first callsParams.copy and then make a copy of the companion Java pipeline component with extra params. So boththe Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluate(dataset)Evaluates the model on a test dataset.

New in version 3.1.0.

Parameters

dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on.
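A brief sketch of evaluating a fitted MultilayerPerceptronClassificationModel on a labeled hold-out set (model as in the classifier example above; the hold-out rows are made up):

from pyspark.ml.linalg import Vectors

holdout = spark.createDataFrame([
    (1.0, Vectors.dense([1.0, 0.0])),
    (0.0, Vectors.dense([1.0, 1.0])),
], ["label", "features"])

summary = model.evaluate(holdout)   # returns a classification summary
print(summary.accuracy)
print(summary.weightedFMeasure())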

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getBlockSize()Gets the value of blockSize or its default value.

getFeaturesCol()Gets the value of featuresCol or its default value.

getInitialWeights()Gets the value of initialWeights or its default value.

New in version 2.0.0.

getLabelCol()Gets the value of labelCol or its default value.

getLayers()Gets the value of layers or its default value.

New in version 1.6.0.


getMaxIter()Gets the value of maxIter or its default value.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getProbabilityCol()Gets the value of probabilityCol or its default value.

getRawPredictionCol()Gets the value of rawPredictionCol or its default value.

getSeed()Gets the value of seed or its default value.

getSolver()Gets the value of solver or its default value.

getStepSize()Gets the value of stepSize or its default value.

getThresholds()Gets the value of thresholds or its default value.

getTol()Gets the value of tol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)Predict label for the given features.

New in version 3.0.0.

predictProbability(value)Predict the probability of each class given the features.

New in version 3.0.0.

predictRaw(value)Raw prediction for each possible label.

New in version 3.0.0.


classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)Sets the value of rawPredictionCol.

New in version 3.0.0.

setThresholds(value)Sets the value of thresholds.

New in version 3.0.0.

summary()
Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on the training set. An exception is thrown if trainingSummary is None.

New in version 3.1.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.


Attributes Documentation

blockSize = Param(parent='undefined', name='blockSize', doc='block size for stacking input data in matrices. Data is stacked within partitions. If block size is more than remaining data in a partition then it is adjusted to the size of this data.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

hasSummary
Indicates whether a training summary exists for this model instance.

New in version 2.1.0.

initialWeights = Param(parent='undefined', name='initialWeights', doc='The initial weights of the model.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

layers = Param(parent='undefined', name='layers', doc='Sizes of layers from input layer to output layer E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 neurons and output layer of 10 neurons.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

numClasses
Number of classes (values which the label can take).

New in version 2.1.0.

numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.

New in version 2.1.0.

paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: l-bfgs, gd.')

stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

weights
The weights of layers.

New in version 2.0.0.
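The flattened weights vector stores, for each pair of consecutive layers, a weight matrix plus a bias vector, so its length can be cross-checked against the layer sizes; a small sanity-check sketch using the [2, 2, 2] model fitted above:

layers = model.getLayers()              # e.g. [2, 2, 2]
expected = sum((n_in + 1) * n_out       # +1 accounts for the bias term
               for n_in, n_out in zip(layers, layers[1:]))
print(model.weights.size == expected)   # True; 12 for [2, 2, 2]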

MultilayerPerceptronClassificationSummary

class pyspark.ml.classification.MultilayerPerceptronClassificationSummary(java_obj=None)
Abstraction for MultilayerPerceptronClassifier Results for a given model.

New in version 3.1.0.


Methods

fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.

Attributes

accuracy  Returns accuracy.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
labelCol  Field in "predictions" which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
precisionByLabel  Returns precision for each label (category).
predictionCol  Field in "predictions" which gives the prediction of each class.
predictions  Dataframe outputted by the model's transform method.
recallByLabel  Returns recall for each label (category).
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.

Methods Documentation

fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.

New in version 3.1.0.
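A short sketch of reading these metrics from a summary object, where summary stands for any MultilayerPerceptronClassificationSummary, for example the value returned by MultilayerPerceptronClassificationModel.evaluate() (the beta value is arbitrary):

print(summary.fMeasureByLabel())          # per-label F1 (default beta=1.0)
print(summary.fMeasureByLabel(beta=2.0))  # per-label F2
print(summary.weightedFMeasure())
print(summary.labels)                     # label values in ascending order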

Attributes Documentation

accuracy
Returns accuracy. (equals to the total number of correctly classified instances out of the total number of instances.)

New in version 3.1.0.

falsePositiveRateByLabelReturns false positive rate for each label (category).

New in version 3.1.0.

labelColField in “predictions” which gives the true label of each instance.


New in version 3.1.0.

labels
Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

New in version 3.1.0.

Notes

In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.

precisionByLabelReturns precision for each label (category).

New in version 3.1.0.

predictionColField in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictionsDataframe outputted by the model’s transform method.

New in version 3.1.0.

recallByLabelReturns recall for each label (category).

New in version 3.1.0.

truePositiveRateByLabelReturns true positive rate for each label (category).

New in version 3.1.0.

weightColField in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRateReturns weighted false positive rate.

New in version 3.1.0.

weightedPrecisionReturns weighted averaged precision.

New in version 3.1.0.

weightedRecallReturns weighted averaged recall. (equals to precision, recall and f-measure)

New in version 3.1.0.

weightedTruePositiveRateReturns weighted true positive rate. (equals to precision, recall and f-measure)

New in version 3.1.0.


MultilayerPerceptronClassificationTrainingSummary

class pyspark.ml.classification.MultilayerPerceptronClassificationTrainingSummary(java_obj=None)
Abstraction for MultilayerPerceptronClassifier Training results.

New in version 3.1.0.

Methods

fMeasureByLabel([beta])  Returns f-measure for each label (category).
weightedFMeasure([beta])  Returns weighted averaged f-measure.

Attributes

accuracy  Returns accuracy.
falsePositiveRateByLabel  Returns false positive rate for each label (category).
labelCol  Field in "predictions" which gives the true label of each instance.
labels  Returns the sequence of labels in ascending order.
objectiveHistory  Objective function (scaled loss + regularization) at each iteration.
precisionByLabel  Returns precision for each label (category).
predictionCol  Field in "predictions" which gives the prediction of each class.
predictions  Dataframe outputted by the model's transform method.
recallByLabel  Returns recall for each label (category).
totalIterations  Number of training iterations until termination.
truePositiveRateByLabel  Returns true positive rate for each label (category).
weightCol  Field in "predictions" which gives the weight of each instance as a vector.
weightedFalsePositiveRate  Returns weighted false positive rate.
weightedPrecision  Returns weighted averaged precision.
weightedRecall  Returns weighted averaged recall.
weightedTruePositiveRate  Returns weighted true positive rate.

Methods Documentation

fMeasureByLabel(beta=1.0)Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)Returns weighted averaged f-measure.

New in version 3.1.0.


Attributes Documentation

accuracyReturns accuracy. (equals to the total number of correctly classified instances out of the total number ofinstances.)

New in version 3.1.0.

falsePositiveRateByLabelReturns false positive rate for each label (category).

New in version 3.1.0.

labelColField in “predictions” which gives the true label of each instance.

New in version 3.1.0.

labelsReturns the sequence of labels in ascending order. This order matches the order used in metrics which arespecified as arrays over labels, e.g., truePositiveRateByLabel.

New in version 3.1.0.

Notes

In most cases, it will be values {0.0, 1.0, . . . , numClasses-1}, However, if the training set is missing alabel, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.

objectiveHistoryObjective function (scaled loss + regularization) at each iteration. It contains one more element, the initialstate, than number of iterations.

New in version 3.1.0.

precisionByLabelReturns precision for each label (category).

New in version 3.1.0.

predictionColField in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictionsDataframe outputted by the model’s transform method.

New in version 3.1.0.

recallByLabelReturns recall for each label (category).

New in version 3.1.0.

totalIterationsNumber of training iterations until termination.

New in version 3.1.0.


truePositiveRateByLabelReturns true positive rate for each label (category).

New in version 3.1.0.

weightColField in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRateReturns weighted false positive rate.

New in version 3.1.0.

weightedPrecisionReturns weighted averaged precision.

New in version 3.1.0.

weightedRecallReturns weighted averaged recall. (equals to precision, recall and f-measure)

New in version 3.1.0.

weightedTruePositiveRateReturns weighted true positive rate. (equals to precision, recall and f-measure)

New in version 3.1.0.

OneVsRest

class pyspark.ml.classification.OneVsRest(*args, **kwargs)
Reduction of Multiclass Classification to Binary Classification. Performs reduction using one against all strategy. For a multiclass classification with k classes, train k models (one per class). Each example is scored against all k models and the model with highest score is picked to label the example.

New in version 2.0.0.

Examples

>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
>>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
>>> df = spark.read.format("libsvm").load(data_path)
>>> lr = LogisticRegression(regParam=0.01)
>>> ovr = OneVsRest(classifier=lr)
>>> ovr.getRawPredictionCol()
'rawPrediction'
>>> ovr.setPredictionCol("newPrediction")
OneVsRest...
>>> model = ovr.fit(df)
>>> model.models[0].coefficients
DenseVector([0.5..., -1.0..., 3.4..., 4.2...])
>>> model.models[1].coefficients
DenseVector([-2.1..., 3.1..., -2.6..., -2.3...])
>>> model.models[2].coefficients
DenseVector([0.3..., -3.4..., 1.0..., -1.1...])
>>> [x.intercept for x in model.models]
[-2.7..., -2.5..., -1.3...]
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF()
>>> model.transform(test0).head().newPrediction
0.0
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF()
>>> model.transform(test1).head().newPrediction
2.0
>>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF()
>>> model.transform(test2).head().newPrediction
0.0
>>> model_path = temp_path + "/ovr_model"
>>> model.save(model_path)
>>> model2 = OneVsRestModel.load(model_path)
>>> model2.transform(test0).head().newPrediction
0.0
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> model.transform(test2).columns
['features', 'rawPrediction', 'newPrediction']
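Any binary Classifier can serve as the base estimator; a minimal sketch that swaps in LinearSVC and trains the per-class models in parallel, reusing the libsvm DataFrame df loaded above (the parameter values are arbitrary):

from pyspark.ml.classification import LinearSVC, OneVsRest

svc = LinearSVC(maxIter=20, regParam=0.1)
# parallelism controls how many of the k binary models are trained at once.
ovr_svc = OneVsRest(classifier=svc, parallelism=2)
svc_model = ovr_svc.fit(df)
print(len(svc_model.models))   # one binary model per class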

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with a randomly generated uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getClassifier()  Gets the value of classifier or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParallelism()  Gets the value of parallelism or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setClassifier(value)  Sets the value of classifier.
setFeaturesCol(value)  Sets the value of featuresCol.
setLabelCol(value)  Sets the value of labelCol.
setParallelism(value)  Sets the value of parallelism.
setParams(*args, **kwargs)  setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1): Sets params for OneVsRest.
setPredictionCol(value)  Sets the value of predictionCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setWeightCol(value)  Sets the value of weightCol.
write()  Returns an MLWriter instance for this ML instance.

Attributes

classifier
featuresCol
labelCol
parallelism
params  Returns all params ordered by name.
predictionCol
rawPredictionCol
weightCol


Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with a randomly generated uid and some extra params. This creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over.

New in version 2.0.0.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

OneVsRest Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.


paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for eachparam map. Each call to next(modelIterator) will return (index, model) where model wasfit using paramMaps[index]. index values may not be sequential.

getClassifier()Gets the value of classifier or its default value.

New in version 2.0.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getLabelCol()Gets the value of labelCol or its default value.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParallelism()Gets the value of parallelism or its default value.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getRawPredictionCol()Gets the value of rawPredictionCol or its default value.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setClassifier(value)Sets the value of classifier.


New in version 2.0.0.

setFeaturesCol(value)Sets the value of featuresCol.

setLabelCol(value)Sets the value of labelCol.

setParallelism(value)Sets the value of parallelism.

setParams(*args, **kwargs)
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1): Sets params for OneVsRest.

New in version 2.0.0.

setPredictionCol(value)Sets the value of predictionCol.

setRawPredictionCol(value)Sets the value of rawPredictionCol.

setWeightCol(value)Sets the value of weightCol.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

classifier = Param(parent='undefined', name='classifier', doc='base binary classifier')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

parallelism = Param(parent='undefined', name='parallelism', doc='the number of threads to use when running parallel algorithms (>= 1).')

paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

OneVsRestModel

class pyspark.ml.classification.OneVsRestModel(models)
Model fitted by OneVsRest. This stores the models resulting from training k binary classifiers: one for each class. Each example is scored against all k models, and the model with the highest score is picked to label the example.

New in version 2.0.0.
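Since the fitted model is essentially a wrapper around the k binary classifiers, the per-class sub-models can be inspected directly; a brief sketch, using the OneVsRestModel named model from the example above:

# models[i] is the binary classification model trained for class i.
for idx, sub_model in enumerate(model.models):
    print(idx, type(sub_model).__name__, sub_model.numFeatures)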


Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with a randomly generated uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getClassifier()  Gets the value of classifier or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFeaturesCol(value)  Sets the value of featuresCol.
setPredictionCol(value)  Sets the value of predictionCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
transform(dataset[, params])  Transforms the input dataset with optional parameters.
write()  Returns an MLWriter instance for this ML instance.


Attributes

classifier
featuresCol
labelCol
params  Returns all params ordered by name.
predictionCol
rawPredictionCol
weightCol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with a randomly generated uid and some extra params. This creates a deepcopy of the embedded paramMap, and copies the embedded and extra parameters over.

New in version 2.0.0.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

OneVsRestModel Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getClassifier()Gets the value of classifier or its default value.

New in version 2.0.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getLabelCol()Gets the value of labelCol or its default value.


getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getRawPredictionCol()Gets the value of rawPredictionCol or its default value.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeaturesCol(value)Sets the value of featuresCol.

setPredictionCol(value)Sets the value of predictionCol.

setRawPredictionCol(value)Sets the value of rawPredictionCol.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset


write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

classifier = Param(parent='undefined', name='classifier', doc='base binary classifier')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

FMClassifier

class pyspark.ml.classification.FMClassifier(*args, **kwargs)
Factorization Machines learning algorithm for classification.

Solver supports:

• gd (normal mini-batch gradient descent)

• adamW (default)

New in version 3.0.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.classification import FMClassifier
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> fm = FMClassifier(factorSize=2)
>>> fm.setSeed(11)
FMClassifier...
>>> model = fm.fit(df)
>>> model.getMaxIter()
100
>>> test0 = spark.createDataFrame([
...     (Vectors.dense(-1.0),),
...     (Vectors.dense(0.5),),
...     (Vectors.dense(1.0),),
...     (Vectors.dense(2.0),)], ["features"])
>>> model.predictRaw(test0.head().features)
DenseVector([22.13..., -22.13...])
>>> model.predictProbability(test0.head().features)
DenseVector([1.0, 0.0])
>>> model.transform(test0).select("features", "probability").show(10, False)
+--------+------------------------------------------+
|features|probability                               |
+--------+------------------------------------------+
|[-1.0]  |[0.9999999997574736,2.425264676902229E-10]|
|[0.5]   |[0.47627851732981163,0.5237214826701884]  |
|[1.0]   |[5.491554426243495E-4,0.9994508445573757] |
|[2.0]   |[2.005766663870645E-10,0.9999999997994233]|
+--------+------------------------------------------+
...
>>> model.intercept
-7.316665276826291
>>> model.linear
DenseVector([14.8232])
>>> model.factors
DenseMatrix(1, 2, [0.0163, -0.0051], 1)
>>> model_path = temp_path + "/fm_model"
>>> model.save(model_path)
>>> model2 = FMClassificationModel.load(model_path)
>>> model2.intercept
-7.316665276826291
>>> model2.linear
DenseVector([14.8232])
>>> model2.factors
DenseMatrix(1, 2, [0.0163, -0.0051], 1)
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
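A short sketch of switching to the plain mini-batch gradient descent solver, where stepSize and miniBatchFraction become the relevant knobs, reusing the toy DataFrame df above (all values are arbitrary):

fm_gd = FMClassifier(
    factorSize=4,
    solver="gd",             # mini-batch gradient descent instead of the default adamW
    stepSize=0.1,
    miniBatchFraction=0.5,   # fraction of the input data used per iteration
    regParam=0.01,
    seed=7,
)
gd_model = fm_gd.fit(df)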

Methods

clear(param)  Clears a param from the param map if it has been explicitly set.
copy([extra])  Creates a copy of this instance with the same uid and some extra params.
explainParam(param)  Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams()  Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra])  Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params])  Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps)  Fits a model to the input dataset for each param map in paramMaps.
getFactorSize()  Gets the value of factorSize or its default value.
getFeaturesCol()  Gets the value of featuresCol or its default value.
getFitIntercept()  Gets the value of fitIntercept or its default value.
getFitLinear()  Gets the value of fitLinear or its default value.
getInitStd()  Gets the value of initStd or its default value.
getLabelCol()  Gets the value of labelCol or its default value.
getMaxIter()  Gets the value of maxIter or its default value.
getMiniBatchFraction()  Gets the value of miniBatchFraction or its default value.
getOrDefault(param)  Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName)  Gets a param by its name.
getPredictionCol()  Gets the value of predictionCol or its default value.
getProbabilityCol()  Gets the value of probabilityCol or its default value.
getRawPredictionCol()  Gets the value of rawPredictionCol or its default value.
getRegParam()  Gets the value of regParam or its default value.
getSeed()  Gets the value of seed or its default value.
getSolver()  Gets the value of solver or its default value.
getStepSize()  Gets the value of stepSize or its default value.
getThresholds()  Gets the value of thresholds or its default value.
getTol()  Gets the value of tol or its default value.
getWeightCol()  Gets the value of weightCol or its default value.
hasDefault(param)  Checks whether a param has a default value.
hasParam(paramName)  Tests whether this instance contains a param with a given (string) name.
isDefined(param)  Checks whether a param is explicitly set by user or has a default value.
isSet(param)  Checks whether a param is explicitly set by user.
load(path)  Reads an ML instance from the input path, a shortcut of read().load(path).
read()  Returns an MLReader instance for this class.
save(path)  Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value)  Sets a parameter in the embedded param map.
setFactorSize(value)  Sets the value of factorSize.
setFeaturesCol(value)  Sets the value of featuresCol.
setFitIntercept(value)  Sets the value of fitIntercept.
setFitLinear(value)  Sets the value of fitLinear.
setInitStd(value)  Sets the value of initStd.
setLabelCol(value)  Sets the value of labelCol.
setMaxIter(value)  Sets the value of maxIter.
setMiniBatchFraction(value)  Sets the value of miniBatchFraction.
setParams(self, *[, featuresCol, labelCol, ...])  Sets Params for FMClassifier.
setPredictionCol(value)  Sets the value of predictionCol.
setProbabilityCol(value)  Sets the value of probabilityCol.
setRawPredictionCol(value)  Sets the value of rawPredictionCol.
setRegParam(value)  Sets the value of regParam.
setSeed(value)  Sets the value of seed.
setSolver(value)  Sets the value of solver.
setStepSize(value)  Sets the value of stepSize.
setThresholds(value)  Sets the value of thresholds.
setTol(value)  Sets the value of tol.
write()  Returns an MLWriter instance for this ML instance.


Attributes

factorSize
featuresCol
fitIntercept
fitLinear
initStd
labelCol
maxIter
miniBatchFraction
params  Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
regParam
seed
solver
stepSize
thresholds
tol
weightCol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first callsParams.copy and then make a copy of the companion Java pipeline component with extra params. So boththe Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns


dict merged param map

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for eachparam map. Each call to next(modelIterator) will return (index, model) where model wasfit using paramMaps[index]. index values may not be sequential.

getFactorSize()Gets the value of factorSize or its default value.

New in version 3.0.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getFitIntercept()Gets the value of fitIntercept or its default value.

getFitLinear()Gets the value of fitLinear or its default value.

New in version 3.0.0.

getInitStd()Gets the value of initStd or its default value.

New in version 3.0.0.

getLabelCol()Gets the value of labelCol or its default value.

getMaxIter()Gets the value of maxIter or its default value.

getMiniBatchFraction()Gets the value of miniBatchFraction or its default value.

New in version 3.0.0.


getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getProbabilityCol()Gets the value of probabilityCol or its default value.

getRawPredictionCol()Gets the value of rawPredictionCol or its default value.

getRegParam()Gets the value of regParam or its default value.

getSeed()Gets the value of seed or its default value.

getSolver()Gets the value of solver or its default value.

getStepSize()Gets the value of stepSize or its default value.

getThresholds()Gets the value of thresholds or its default value.

getTol()Gets the value of tol or its default value.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFactorSize(value)Sets the value of factorSize.


New in version 3.0.0.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setFitIntercept(value)Sets the value of fitIntercept.

New in version 3.0.0.

setFitLinear(value)Sets the value of fitLinear.

New in version 3.0.0.

setInitStd(value)Sets the value of initStd.

New in version 3.0.0.

setLabelCol(value)Sets the value of labelCol.

New in version 3.0.0.

setMaxIter(value)Sets the value of maxIter.

New in version 3.0.0.

setMiniBatchFraction(value)Sets the value of miniBatchFraction.

New in version 3.0.0.

setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, tol=1e-6, solver="adamW", thresholds=None, seed=None)

Sets Params for FMClassifier.

New in version 3.0.0.
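As a quick orientation, a minimal sketch of constructing an FMClassifier with a few of these params and fitting it; the toy DataFrame and an active SparkSession named spark are assumed:

# Hedged sketch: assumes an active SparkSession named `spark`.
from pyspark.ml.classification import FMClassifier
from pyspark.ml.linalg import Vectors

train_df = spark.createDataFrame(
    [(1.0, Vectors.dense([0.0, 0.0])),
     (0.0, Vectors.dense([1.0, 1.0]))],
    ["label", "features"])

# The keyword arguments mirror the setParams signature above.
fm = FMClassifier(factorSize=4, regParam=0.01, maxIter=50, stepSize=0.1, seed=42)
model = fm.fit(train_df)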

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)
Sets the value of rawPredictionCol.

New in version 3.0.0.

setRegParam(value)
Sets the value of regParam.

New in version 3.0.0.


setSeed(value)
Sets the value of seed.

New in version 3.0.0.

setSolver(value)
Sets the value of solver.

New in version 3.0.0.

setStepSize(value)
Sets the value of stepSize.

New in version 3.0.0.

setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.

setTol(value)
Sets the value of tol.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

factorSize = Param(parent='undefined', name='factorSize', doc='Dimensionality of the factor vectors, which are used to get pairwise interactions between variables')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')

fitLinear = Param(parent='undefined', name='fitLinear', doc='whether to fit linear term (aka 1-way term)')

initStd = Param(parent='undefined', name='initStd', doc='standard deviation of initial coefficients')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

miniBatchFraction = Param(parent='undefined', name='miniBatchFraction', doc='fraction of the input data set that should be used for one iteration of gradient descent')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')

seed = Param(parent='undefined', name='seed', doc='random seed.')

solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: gd, adamW. (Default adamW)')

stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')


weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

FMClassificationModel

class pyspark.ml.classification.FMClassificationModel(java_model=None)
Model fitted by FMClassifier.

New in version 3.0.0.

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset) Evaluates the model on a test dataset.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFactorSize() Gets the value of factorSize or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFitIntercept() Gets the value of fitIntercept or its default value.
getFitLinear() Gets the value of fitLinear or its default value.
getInitStd() Gets the value of initStd or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getMiniBatchFraction() Gets the value of miniBatchFraction or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getProbabilityCol() Gets the value of probabilityCol or its default value.
getRawPredictionCol() Gets the value of rawPredictionCol or its default value.
getRegParam() Gets the value of regParam or its default value.
getSeed() Gets the value of seed or its default value.
getSolver() Gets the value of solver or its default value.
getStepSize() Gets the value of stepSize or its default value.
getThresholds() Gets the value of thresholds or its default value.
getTol() Gets the value of tol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) Predict label for the given features.
predictProbability(value) Predict the probability of each class given the features.
predictRaw(value) Raw prediction for each possible label.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setPredictionCol(value) Sets the value of predictionCol.
setProbabilityCol(value) Sets the value of probabilityCol.
setRawPredictionCol(value) Sets the value of rawPredictionCol.
setThresholds(value) Sets the value of thresholds.
summary() Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on the training set.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

factorSize
factors Model factor term.
featuresCol
fitIntercept
fitLinear
hasSummary Indicates whether a training summary exists for this model instance.
initStd
intercept Model intercept.
labelCol
linear Model linear term.
maxIter
miniBatchFraction
numClasses Number of classes (values which the label can take).
numFeatures Returns the number of features the model was trained on.
params Returns all params ordered by name.
predictionCol
probabilityCol
rawPredictionCol
regParam
seed
solver
stepSize
thresholds
tol
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then make a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluate(dataset)
Evaluates the model on a test dataset.

New in version 3.1.0.

Parameters

dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on.

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getFactorSize()
Gets the value of factorSize or its default value.

New in version 3.0.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.


getFitIntercept()
Gets the value of fitIntercept or its default value.

getFitLinear()
Gets the value of fitLinear or its default value.

New in version 3.0.0.

getInitStd()
Gets the value of initStd or its default value.

New in version 3.0.0.

getLabelCol()
Gets the value of labelCol or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getMiniBatchFraction()
Gets the value of miniBatchFraction or its default value.

New in version 3.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getRawPredictionCol()
Gets the value of rawPredictionCol or its default value.

getRegParam()
Gets the value of regParam or its default value.

getSeed()
Gets the value of seed or its default value.

getSolver()
Gets the value of solver or its default value.

getStepSize()
Gets the value of stepSize or its default value.

getThresholds()
Gets the value of thresholds or its default value.

getTol()
Gets the value of tol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.


hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
Predict label for the given features.

New in version 3.0.0.

predictProbability(value)
Predict the probability of each class given the features.

New in version 3.0.0.

predictRaw(value)
Raw prediction for each possible label.

New in version 3.0.0.
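A short sketch of these single-vector prediction helpers on a fitted model; the variable model is assumed to be an FMClassificationModel produced by FMClassifier.fit:

# Hedged sketch: `model` is assumed to be a fitted FMClassificationModel.
from pyspark.ml.linalg import Vectors

features = Vectors.dense([0.5, 1.5])
label = model.predict(features)                      # predicted class index
probabilities = model.predictProbability(features)   # per-class probability vector
raw_scores = model.predictRaw(features)              # raw (pre-probability) scores per class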

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

setRawPredictionCol(value)
Sets the value of rawPredictionCol.

New in version 3.0.0.

setThresholds(value)
Sets the value of thresholds.

New in version 3.0.0.

summary()
Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model trained on the training set. An exception is thrown if trainingSummary is None.

New in version 3.1.0.


transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset
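For example, a minimal sketch of calling transform on a fitted model; test_df is an assumed DataFrame with a features column:

# Hedged sketch: `model` is a fitted FMClassificationModel, `test_df` an assumed DataFrame.
predictions = model.transform(test_df)
# The output keeps the input columns and typically adds rawPrediction, probability and prediction.
predictions.select("features", "probability", "prediction").show(5, truncate=False)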

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

factorSize = Param(parent='undefined', name='factorSize', doc='Dimensionality of the factor vectors, which are used to get pairwise interactions between variables')

factors
Model factor term.

New in version 3.0.0.

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')

fitLinear = Param(parent='undefined', name='fitLinear', doc='whether to fit linear term (aka 1-way term)')

hasSummary
Indicates whether a training summary exists for this model instance.

New in version 2.1.0.

initStd = Param(parent='undefined', name='initStd', doc='standard deviation of initial coefficients')

intercept
Model intercept.

New in version 3.0.0.

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

linear
Model linear term.

New in version 3.0.0.

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

miniBatchFraction = Param(parent='undefined', name='miniBatchFraction', doc='fraction of the input data set that should be used for one iteration of gradient descent')

numClasses
Number of classes (values which the label can take).

New in version 2.1.0.

numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.

New in version 2.1.0.


params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')

seed = Param(parent='undefined', name='seed', doc='random seed.')

solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: gd, adamW. (Default adamW)')

stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')

thresholds = Param(parent='undefined', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.")

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

FMClassificationSummary

class pyspark.ml.classification.FMClassificationSummary(java_obj=None)
Abstraction for FMClassifier Results for a given model.

New in version 3.1.0.

Methods

fMeasureByLabel([beta]) Returns f-measure for each label (category).
weightedFMeasure([beta]) Returns weighted averaged f-measure.

Attributes

accuracy Returns accuracy.
areaUnderROC Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel Returns false positive rate for each label (category).
labelCol Field in “predictions” which gives the true label of each instance.
labels Returns the sequence of labels in ascending order.
pr Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel Returns precision for each label (category).
precisionByThreshold Returns a dataframe with two fields (threshold, precision) curve.
predictionCol Field in “predictions” which gives the prediction of each class.
predictions Dataframe outputted by the model’s transform method.
recallByLabel Returns recall for each label (category).
recallByThreshold Returns a dataframe with two fields (threshold, recall) curve.
roc Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol Field in “predictions” which gives the probability or raw prediction of each class as a vector.
truePositiveRateByLabel Returns true positive rate for each label (category).
weightCol Field in “predictions” which gives the weight of each instance as a vector.
weightedFalsePositiveRate Returns weighted false positive rate.
weightedPrecision Returns weighted averaged precision.
weightedRecall Returns weighted averaged recall.
weightedTruePositiveRate Returns weighted true positive rate.

Methods Documentation

fMeasureByLabel(beta=1.0)
Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)
Returns weighted averaged f-measure.

New in version 3.1.0.

Attributes Documentation

accuracy
Returns accuracy. (equals to the total number of correctly classified instances out of the total number of instances.)

New in version 3.1.0.

areaUnderROC
Computes the area under the receiver operating characteristic (ROC) curve.

New in version 3.1.0.

fMeasureByThreshold
Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.

New in version 3.1.0.

falsePositiveRateByLabel
Returns false positive rate for each label (category).

New in version 3.1.0.


labelCol
Field in “predictions” which gives the true label of each instance.

New in version 3.1.0.

labels
Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

New in version 3.1.0.

Notes

In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.

pr
Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.

New in version 3.1.0.

precisionByLabel
Returns precision for each label (category).

New in version 3.1.0.

precisionByThreshold
Returns a dataframe with two fields (threshold, precision) curve. Every possible probability obtained in transforming the dataset are used as thresholds used in calculating the precision.

New in version 3.1.0.

predictionCol
Field in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictions
Dataframe outputted by the model’s transform method.

New in version 3.1.0.

recallByLabel
Returns recall for each label (category).

New in version 3.1.0.

recallByThreshold
Returns a dataframe with two fields (threshold, recall) curve. Every possible probability obtained in transforming the dataset are used as thresholds used in calculating the recall.

New in version 3.1.0.

roc
Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.

New in version 3.1.0.


Notes

Wikipedia reference

scoreCol
Field in “predictions” which gives the probability or raw prediction of each class as a vector.

New in version 3.1.0.

truePositiveRateByLabel
Returns true positive rate for each label (category).

New in version 3.1.0.

weightCol
Field in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRate
Returns weighted false positive rate.

New in version 3.1.0.

weightedPrecision
Returns weighted averaged precision.

New in version 3.1.0.

weightedRecall
Returns weighted averaged recall. (equals to precision, recall and f-measure)

New in version 3.1.0.

weightedTruePositiveRate
Returns weighted true positive rate. (equals to precision, recall and f-measure)

New in version 3.1.0.
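A brief sketch of obtaining such a summary from a fitted model via evaluate() and reading a few of the metrics above; model and test_df are assumed:

# Hedged sketch: `model` is a fitted FMClassificationModel, `test_df` an assumed test DataFrame.
test_summary = model.evaluate(test_df)    # returns an FMClassificationSummary

print(test_summary.accuracy)              # overall accuracy
print(test_summary.areaUnderROC)          # area under the ROC curve
test_summary.roc.show(5)                  # DataFrame with FPR/TPR fields
print(test_summary.weightedPrecision, test_summary.weightedRecall)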

FMClassificationTrainingSummary

class pyspark.ml.classification.FMClassificationTrainingSummary(java_obj=None)
Abstraction for FMClassifier Training results.

New in version 3.1.0.

Methods

fMeasureByLabel([beta]) Returns f-measure for each label (category).
weightedFMeasure([beta]) Returns weighted averaged f-measure.


Attributes

accuracy Returns accuracy.
areaUnderROC Computes the area under the receiver operating characteristic (ROC) curve.
fMeasureByThreshold Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
falsePositiveRateByLabel Returns false positive rate for each label (category).
labelCol Field in “predictions” which gives the true label of each instance.
labels Returns the sequence of labels in ascending order.
objectiveHistory Objective function (scaled loss + regularization) at each iteration.
pr Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.
precisionByLabel Returns precision for each label (category).
precisionByThreshold Returns a dataframe with two fields (threshold, precision) curve.
predictionCol Field in “predictions” which gives the prediction of each class.
predictions Dataframe outputted by the model’s transform method.
recallByLabel Returns recall for each label (category).
recallByThreshold Returns a dataframe with two fields (threshold, recall) curve.
roc Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
scoreCol Field in “predictions” which gives the probability or raw prediction of each class as a vector.
totalIterations Number of training iterations until termination.
truePositiveRateByLabel Returns true positive rate for each label (category).
weightCol Field in “predictions” which gives the weight of each instance as a vector.
weightedFalsePositiveRate Returns weighted false positive rate.
weightedPrecision Returns weighted averaged precision.
weightedRecall Returns weighted averaged recall.
weightedTruePositiveRate Returns weighted true positive rate.

Methods Documentation

fMeasureByLabel(beta=1.0)
Returns f-measure for each label (category).

New in version 3.1.0.

weightedFMeasure(beta=1.0)
Returns weighted averaged f-measure.

New in version 3.1.0.


Attributes Documentation

accuracy
Returns accuracy. (equals to the total number of correctly classified instances out of the total number of instances.)

New in version 3.1.0.

areaUnderROC
Computes the area under the receiver operating characteristic (ROC) curve.

New in version 3.1.0.

fMeasureByThreshold
Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.

New in version 3.1.0.

falsePositiveRateByLabel
Returns false positive rate for each label (category).

New in version 3.1.0.

labelCol
Field in “predictions” which gives the true label of each instance.

New in version 3.1.0.

labels
Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

New in version 3.1.0.

Notes

In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses.

objectiveHistory
Objective function (scaled loss + regularization) at each iteration. It contains one more element, the initial state, than number of iterations.

New in version 3.1.0.

pr
Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it.

New in version 3.1.0.

precisionByLabel
Returns precision for each label (category).

New in version 3.1.0.

precisionByThreshold
Returns a dataframe with two fields (threshold, precision) curve. Every possible probability obtained in transforming the dataset are used as thresholds used in calculating the precision.

New in version 3.1.0.


predictionCol
Field in “predictions” which gives the prediction of each class.

New in version 3.1.0.

predictions
Dataframe outputted by the model’s transform method.

New in version 3.1.0.

recallByLabel
Returns recall for each label (category).

New in version 3.1.0.

recallByThreshold
Returns a dataframe with two fields (threshold, recall) curve. Every possible probability obtained in transforming the dataset are used as thresholds used in calculating the recall.

New in version 3.1.0.

roc
Returns the receiver operating characteristic (ROC) curve, which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.

New in version 3.1.0.

Notes

Wikipedia reference

scoreCol
Field in “predictions” which gives the probability or raw prediction of each class as a vector.

New in version 3.1.0.

totalIterations
Number of training iterations until termination.

New in version 3.1.0.

truePositiveRateByLabel
Returns true positive rate for each label (category).

New in version 3.1.0.

weightCol
Field in “predictions” which gives the weight of each instance as a vector.

New in version 3.1.0.

weightedFalsePositiveRate
Returns weighted false positive rate.

New in version 3.1.0.

weightedPrecision
Returns weighted averaged precision.

New in version 3.1.0.

weightedRecall
Returns weighted averaged recall. (equals to precision, recall and f-measure)

New in version 3.1.0.


weightedTruePositiveRate
Returns weighted true positive rate. (equals to precision, recall and f-measure)

New in version 3.1.0.
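A minimal sketch of reading the training summary, following the summary() signature documented for FMClassificationModel above; `model` is an assumed fitted model:

# Hedged sketch: `model` is a fitted FMClassificationModel that still carries a training summary.
if model.hasSummary:
    training_summary = model.summary()           # FMClassificationTrainingSummary
    print(training_summary.totalIterations)      # number of training iterations
    print(training_summary.objectiveHistory)     # loss per iteration (plus the initial state)
    print(training_summary.accuracy)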

3.3.5 Clustering

BisectingKMeans(*args, **kwargs) A bisecting k-means algorithm based on the paper “A comparison of document clustering techniques” by Steinbach, Karypis, and Kumar, with modification to fit Spark.
BisectingKMeansModel([java_model]) Model fitted by BisectingKMeans.
BisectingKMeansSummary([java_obj]) Bisecting KMeans clustering results for a given model.
KMeans(*args, **kwargs) K-means clustering with a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al).
KMeansModel([java_model]) Model fitted by KMeans.
KMeansSummary([java_obj]) Summary of KMeans.
GaussianMixture(*args, **kwargs) GaussianMixture clustering.
GaussianMixtureModel([java_model]) Model fitted by GaussianMixture.
GaussianMixtureSummary([java_obj]) Gaussian mixture clustering results for a given model.
LDA(*args, **kwargs) Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
LDAModel([java_model]) Latent Dirichlet Allocation (LDA) model.
LocalLDAModel([java_model]) Local (non-distributed) model fitted by LDA.
DistributedLDAModel([java_model]) Distributed model fitted by LDA.
PowerIterationClustering(*args, **kwargs) Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and Cohen. From the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise similarity matrix of the data.

BisectingKMeans

class pyspark.ml.clustering.BisectingKMeans(*args, **kwargs)
A bisecting k-means algorithm based on the paper “A comparison of document clustering techniques” by Steinbach, Karypis, and Kumar, with modification to fit Spark. The algorithm starts from a single cluster that contains all points. Iteratively it finds divisible clusters on the bottom level and bisects each of them using k-means, until there are k leaf clusters in total or no leaf clusters are divisible. The bisecting steps of clusters on the same level are grouped together to increase parallelism. If bisecting all divisible clusters on the bottom level would result more than k leaf clusters, larger clusters get higher priority.

New in version 2.0.0.


Examples

>>> from pyspark.ml.linalg import Vectors
>>> data = [(Vectors.dense([0.0, 0.0]), 2.0), (Vectors.dense([1.0, 1.0]), 2.0),
...         (Vectors.dense([9.0, 8.0]), 2.0), (Vectors.dense([8.0, 9.0]), 2.0)]
>>> df = spark.createDataFrame(data, ["features", "weighCol"])
>>> bkm = BisectingKMeans(k=2, minDivisibleClusterSize=1.0)
>>> bkm.setMaxIter(10)
BisectingKMeans...
>>> bkm.getMaxIter()
10
>>> bkm.clear(bkm.maxIter)
>>> bkm.setSeed(1)
BisectingKMeans...
>>> bkm.setWeightCol("weighCol")
BisectingKMeans...
>>> bkm.getSeed()
1
>>> bkm.clear(bkm.seed)
>>> model = bkm.fit(df)
>>> model.getMaxIter()
20
>>> model.setPredictionCol("newPrediction")
BisectingKMeansModel...
>>> model.predict(df.head().features)
0
>>> centers = model.clusterCenters()
>>> len(centers)
2
>>> model.computeCost(df)
2.0
>>> model.hasSummary
True
>>> summary = model.summary
>>> summary.k
2
>>> summary.clusterSizes
[2, 2]
>>> summary.trainingCost
4.000...
>>> transformed = model.transform(df).select("features", "newPrediction")
>>> rows = transformed.collect()
>>> rows[0].newPrediction == rows[1].newPrediction
True
>>> rows[2].newPrediction == rows[3].newPrediction
True
>>> bkm_path = temp_path + "/bkm"
>>> bkm.save(bkm_path)
>>> bkm2 = BisectingKMeans.load(bkm_path)
>>> bkm2.getK()
2
>>> bkm2.getDistanceMeasure()
'euclidean'
>>> model_path = temp_path + "/bkm_model"
>>> model.save(model_path)
>>> model2 = BisectingKMeansModel.load(model_path)
>>> model2.hasSummary
False
>>> model.clusterCenters()[0] == model2.clusterCenters()[0]
array([ True, True], dtype=bool)
>>> model.clusterCenters()[1] == model2.clusterCenters()[1]
array([ True, True], dtype=bool)
>>> model.transform(df).take(1) == model2.transform(df).take(1)
True

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getDistanceMeasure() Gets the value of distanceMeasure or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getK() Gets the value of k or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getMinDivisibleClusterSize() Gets the value of minDivisibleClusterSize or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getSeed() Gets the value of seed or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setDistanceMeasure(value) Sets the value of distanceMeasure.
setFeaturesCol(value) Sets the value of featuresCol.
setK(value) Sets the value of k.
setMaxIter(value) Sets the value of maxIter.
setMinDivisibleClusterSize(value) Sets the value of minDivisibleClusterSize.
setParams(self, \*[, featuresCol, ...]) Sets params for BisectingKMeans.
setPredictionCol(value) Sets the value of predictionCol.
setSeed(value) Sets the value of seed.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.

Attributes

distanceMeasure
featuresCol
k
maxIter
minDivisibleClusterSize
params Returns all params ordered by name.
predictionCol
seed
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then make a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.


extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)
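For instance, a small sketch of passing a list of param maps to fit, which returns one fitted model per map; df is the DataFrame from the example above, and the chosen k values are illustrative:

# Hedged sketch: `df` is a DataFrame with a "features" column, as in the example above.
from pyspark.ml.clustering import BisectingKMeans

bkm = BisectingKMeans(maxIter=10, seed=1)
param_maps = [{bkm.k: k} for k in (2, 3, 4)]

# With a list of param maps, fit returns a list of models, one per map.
models = bkm.fit(df, param_maps)
for params, fitted in zip(param_maps, models):
    print(params[bkm.k], len(fitted.clusterCenters()))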

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getDistanceMeasure()
Gets the value of distanceMeasure or its default value.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getK()
Gets the value of k or its default value.

New in version 2.0.0.

getMaxIter()
Gets the value of maxIter or its default value.

getMinDivisibleClusterSize()
Gets the value of minDivisibleClusterSize or its default value.

New in version 2.0.0.


getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setDistanceMeasure(value)
Sets the value of distanceMeasure.

New in version 2.4.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 2.0.0.

setK(value)
Sets the value of k.

New in version 2.0.0.

setMaxIter(value)
Sets the value of maxIter.

New in version 2.0.0.

setMinDivisibleClusterSize(value)
Sets the value of minDivisibleClusterSize.

New in version 2.0.0.


setParams(self, \*, featuresCol="features", predictionCol="prediction", maxIter=20, seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean", weightCol=None)
Sets params for BisectingKMeans.

New in version 2.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 2.0.0.

setSeed(value)
Sets the value of seed.

New in version 2.0.0.

setWeightCol(value)
Sets the value of weightCol.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

distanceMeasure = Param(parent='undefined', name='distanceMeasure', doc="the distance measure. Supported options: 'euclidean' and 'cosine'.")

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

k = Param(parent='undefined', name='k', doc='The desired number of leaf clusters. Must be > 1.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

minDivisibleClusterSize = Param(parent='undefined', name='minDivisibleClusterSize', doc='The minimum number of points (if >= 1.0) or the minimum proportion of points (if < 1.0) of a divisible cluster.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

BisectingKMeansModel

class pyspark.ml.clustering.BisectingKMeansModel(java_model=None)
Model fitted by BisectingKMeans.

New in version 2.0.0.


Methods

clear(param) Clears a param from the param map if it has been explicitly set.
clusterCenters() Get the cluster centers, represented as a list of NumPy arrays.
computeCost(dataset) Computes the sum of squared distances between the input points and their corresponding cluster centers.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getDistanceMeasure() Gets the value of distanceMeasure or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getK() Gets the value of k or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getMinDivisibleClusterSize() Gets the value of minDivisibleClusterSize or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getSeed() Gets the value of seed or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) Predict label for the given features.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setPredictionCol(value) Sets the value of predictionCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

distanceMeasure
featuresCol
hasSummary Indicates whether a training summary exists for this model instance.
k
maxIter
minDivisibleClusterSize
params Returns all params ordered by name.
predictionCol
seed
summary Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set.
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

clusterCenters()
Get the cluster centers, represented as a list of NumPy arrays.

New in version 2.0.0.

computeCost(dataset)
Computes the sum of squared distances between the input points and their corresponding cluster centers.

Deprecated since version 3.0.0: It will be removed in future versions. Use ClusteringEvaluator instead. You can also get the cost on the training dataset in the summary.

New in version 2.0.0.
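Since computeCost is deprecated, here is a hedged sketch of the suggested replacement using ClusteringEvaluator; df and model are as in the example above, and note the evaluator reports silhouette rather than the summed squared distance, which remains available as summary.trainingCost:

# Hedged sketch: evaluate clustering quality instead of calling the deprecated computeCost.
from pyspark.ml.evaluation import ClusteringEvaluator

predictions = model.transform(df)
evaluator = ClusteringEvaluator(predictionCol="newPrediction")
print(evaluator.evaluate(predictions))   # silhouette score (the default metric)

# The training-set cost is still exposed on the training summary of a freshly fitted model.
print(model.summary.trainingCost)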

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then make a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.


extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getDistanceMeasure()
Gets the value of distanceMeasure or its default value.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getK()
Gets the value of k or its default value.

New in version 2.0.0.

getMaxIter()
Gets the value of maxIter or its default value.

getMinDivisibleClusterSize()
Gets the value of minDivisibleClusterSize or its default value.

New in version 2.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).


predict(value)
Predict label for the given features.

New in version 3.0.0.

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

distanceMeasure = Param(parent='undefined', name='distanceMeasure', doc="the distance measure. Supported options: 'euclidean' and 'cosine'.")

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

hasSummary
Indicates whether a training summary exists for this model instance.

New in version 2.1.0.

k = Param(parent='undefined', name='k', doc='The desired number of leaf clusters. Must be > 1.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

minDivisibleClusterSize = Param(parent='undefined', name='minDivisibleClusterSize', doc='The minimum number of points (if >= 1.0) or the minimum proportion of points (if < 1.0) of a divisible cluster.')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')


summary
Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set. An exception is thrown if no summary exists.

New in version 2.1.0.

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

BisectingKMeansSummary

class pyspark.ml.clustering.BisectingKMeansSummary(java_obj=None)
Bisecting KMeans clustering results for a given model.

New in version 2.1.0.

Attributes

cluster DataFrame of predicted cluster centers for each training data point.
clusterSizes Size of (number of data points in) each cluster.
featuresCol Name for column of features in predictions.
k The number of clusters the model was trained with.
numIter Number of iterations.
predictionCol Name for column of predicted clusters in predictions.
predictions DataFrame produced by the model’s transform method.
trainingCost Sum of squared distances to the nearest centroid for all points in the training dataset.

Attributes Documentation

cluster
DataFrame of predicted cluster centers for each training data point.

New in version 2.1.0.

clusterSizes
Size of (number of data points in) each cluster.

New in version 2.1.0.

featuresCol
Name for column of features in predictions.

New in version 2.1.0.

k
The number of clusters the model was trained with.

New in version 2.1.0.

numIter
Number of iterations.

New in version 2.4.0.


predictionCol
Name for column of predicted clusters in predictions.

New in version 2.1.0.

predictions
DataFrame produced by the model’s transform method.

New in version 2.1.0.

trainingCost
Sum of squared distances to the nearest centroid for all points in the training dataset. This is equivalent to sklearn’s inertia.

New in version 3.0.0.
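A short sketch of reading these summary attributes from a freshly trained model, continuing the BisectingKMeans example above where `model` still carries its training summary:

# Hedged sketch: `model` is a freshly fitted BisectingKMeansModel (hasSummary is True).
summary = model.summary
print(summary.k)              # number of clusters the model was trained with
print(summary.clusterSizes)   # points per cluster, e.g. [2, 2] in the example above
print(summary.numIter)        # iterations actually run
print(summary.trainingCost)   # sum of squared distances to the nearest centroid
summary.predictions.show()    # DataFrame produced by the model's transform method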

KMeans

class pyspark.ml.clustering.KMeans(*args, **kwargs)
K-means clustering with a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al).

New in version 1.5.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> data = [(Vectors.dense([0.0, 0.0]), 2.0), (Vectors.dense([1.0, 1.0]), 2.0),
...         (Vectors.dense([9.0, 8.0]), 2.0), (Vectors.dense([8.0, 9.0]), 2.0)]
>>> df = spark.createDataFrame(data, ["features", "weighCol"])
>>> kmeans = KMeans(k=2)
>>> kmeans.setSeed(1)
KMeans...
>>> kmeans.setWeightCol("weighCol")
KMeans...
>>> kmeans.setMaxIter(10)
KMeans...
>>> kmeans.getMaxIter()
10
>>> kmeans.clear(kmeans.maxIter)
>>> model = kmeans.fit(df)
>>> model.getDistanceMeasure()
'euclidean'
>>> model.setPredictionCol("newPrediction")
KMeansModel...
>>> model.predict(df.head().features)
0
>>> centers = model.clusterCenters()
>>> len(centers)
2
>>> transformed = model.transform(df).select("features", "newPrediction")
>>> rows = transformed.collect()
>>> rows[0].newPrediction == rows[1].newPrediction
True
>>> rows[2].newPrediction == rows[3].newPrediction
True
>>> model.hasSummary
True
>>> summary = model.summary
>>> summary.k
2
>>> summary.clusterSizes
[2, 2]
>>> summary.trainingCost
4.0
>>> kmeans_path = temp_path + "/kmeans"
>>> kmeans.save(kmeans_path)
>>> kmeans2 = KMeans.load(kmeans_path)
>>> kmeans2.getK()
2
>>> model_path = temp_path + "/kmeans_model"
>>> model.save(model_path)
>>> model2 = KMeansModel.load(model_path)
>>> model2.hasSummary
False
>>> model.clusterCenters()[0] == model2.clusterCenters()[0]
array([ True, True], dtype=bool)
>>> model.clusterCenters()[1] == model2.clusterCenters()[1]
array([ True, True], dtype=bool)
>>> model.transform(df).take(1) == model2.transform(df).take(1)
True

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getDistanceMeasure() Gets the value of distanceMeasure or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getInitMode() Gets the value of initMode
getInitSteps() Gets the value of initSteps
getK() Gets the value of k
getMaxIter() Gets the value of maxIter or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getSeed() Gets the value of seed or its default value.
getTol() Gets the value of tol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setDistanceMeasure(value) Sets the value of distanceMeasure.
setFeaturesCol(value) Sets the value of featuresCol.
setInitMode(value) Sets the value of initMode.
setInitSteps(value) Sets the value of initSteps.
setK(value) Sets the value of k.
setMaxIter(value) Sets the value of maxIter.
setParams(self, \*[, featuresCol, ...]) Sets params for KMeans.
setPredictionCol(value) Sets the value of predictionCol.
setSeed(value) Sets the value of seed.
setTol(value) Sets the value of tol.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.

Attributes

distanceMeasure
featuresCol
initMode
initSteps
k
maxIter
params Returns all params ordered by name.
predictionCol
seed
tol
weightCol


Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then make a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

3.3. ML 743

Page 748: pyspark Documentation

pyspark Documentation, Release master

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
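A short sketch of consuming that iterator, assuming the DataFrame df from the fit() sketch above:

# Hedged sketch: collect the models produced by fitMultiple() back into list order.
from pyspark.ml.clustering import KMeans

kmeans = KMeans(seed=1)
param_maps = [{kmeans.k: 2}, {kmeans.k: 3}, {kmeans.k: 4}]

models = [None] * len(param_maps)
for index, model in kmeans.fitMultiple(df, param_maps):
    models[index] = model   # indices may arrive out of order, so store by position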

getDistanceMeasure() Gets the value of distanceMeasure or its default value.

getFeaturesCol() Gets the value of featuresCol or its default value.

getInitMode() Gets the value of initMode.

New in version 1.5.0.

getInitSteps() Gets the value of initSteps.

New in version 1.5.0.

getK() Gets the value of k.

New in version 1.5.0.

getMaxIter() Gets the value of maxIter or its default value.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName) Gets a param by its name.

getPredictionCol() Gets the value of predictionCol or its default value.

getSeed() Gets the value of seed or its default value.

getTol() Gets the value of tol or its default value.

getWeightCol() Gets the value of weightCol or its default value.

hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isSet(param) Checks whether a param is explicitly set by user.

classmethod load(path) Reads an ML instance from the input path, a shortcut of read().load(path).


classmethod read() Returns an MLReader instance for this class.

save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value) Sets a parameter in the embedded param map.

setDistanceMeasure(value) Sets the value of distanceMeasure.

New in version 2.4.0.

setFeaturesCol(value) Sets the value of featuresCol.

New in version 1.5.0.

setInitMode(value) Sets the value of initMode.

New in version 1.5.0.

setInitSteps(value) Sets the value of initSteps.

New in version 1.5.0.

setK(value) Sets the value of k.

New in version 1.5.0.

setMaxIter(value) Sets the value of maxIter.

New in version 1.5.0.

setParams(self, \*, featuresCol="features", predictionCol="prediction", k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, distanceMeasure="euclidean", weightCol=None)

Sets params for KMeans.

New in version 1.5.0.

setPredictionCol(value) Sets the value of predictionCol.

New in version 1.5.0.

setSeed(value) Sets the value of seed.

New in version 1.5.0.

setTol(value) Sets the value of tol.

New in version 1.5.0.

setWeightCol(value) Sets the value of weightCol.

New in version 3.0.0.


write() Returns an MLWriter instance for this ML instance.

Attributes Documentation

distanceMeasure = Param(parent='undefined', name='distanceMeasure', doc="the distance measure. Supported options: 'euclidean' and 'cosine'.")

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

initMode = Param(parent='undefined', name='initMode', doc='The initialization algorithm. This can be either "random" to choose random points as initial cluster centers, or "k-means||" to use a parallel variant of k-means++')

initSteps = Param(parent='undefined', name='initSteps', doc='The number of steps for k-means|| initialization mode. Must be > 0.')

k = Param(parent='undefined', name='k', doc='The number of clusters to create. Must be > 1.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

KMeansModel

class pyspark.ml.clustering.KMeansModel(java_model=None) Model fitted by KMeans.

New in version 1.5.0.

Methods

clear(param) Clears a param from the param map if it has been explicitly set.

clusterCenters() Get the cluster centers, represented as a list of NumPy arrays.

copy([extra]) Creates a copy of this instance with the same uid and some extra params.

explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

getDistanceMeasure() Gets the value of distanceMeasure or its default value.

getFeaturesCol() Gets the value of featuresCol or its default value.

getInitMode() Gets the value of initMode.

getInitSteps() Gets the value of initSteps.

getK() Gets the value of k.

getMaxIter() Gets the value of maxIter or its default value.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.

getParam(paramName) Gets a param by its name.

getPredictionCol() Gets the value of predictionCol or its default value.

getSeed() Gets the value of seed or its default value.

getTol() Gets the value of tol or its default value.

getWeightCol() Gets the value of weightCol or its default value.

hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isSet(param) Checks whether a param is explicitly set by user.

load(path) Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value) Predict label for the given features.

read() Returns an MLReader instance for this class.

save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value) Sets a parameter in the embedded param map.

setFeaturesCol(value) Sets the value of featuresCol.

setPredictionCol(value) Sets the value of predictionCol.

transform(dataset[, params]) Transforms the input dataset with optional parameters.

write() Returns a GeneralMLWriter instance for this ML instance.

Attributes

distanceMeasure
featuresCol
hasSummary Indicates whether a training summary exists for this model instance.
initMode
initSteps
k
maxIter
params Returns all params ordered by name.
predictionCol
seed
summary Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set.
tol
weightCol

Methods Documentation

clear(param) Clears a param from the param map if it has been explicitly set.

clusterCenters() Get the cluster centers, represented as a list of NumPy arrays.

New in version 1.5.0.
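A brief sketch, assuming model is a KMeansModel fitted as in the fit() sketch earlier in this section:

# Hedged sketch: inspect the learned centroids (one NumPy array per cluster).
for i, center in enumerate(model.clusterCenters()):
    print(i, center)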

copy(extra=None) Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getDistanceMeasure() Gets the value of distanceMeasure or its default value.

getFeaturesCol() Gets the value of featuresCol or its default value.

getInitMode() Gets the value of initMode.

New in version 1.5.0.

getInitSteps() Gets the value of initSteps.

New in version 1.5.0.


getK() Gets the value of k.

New in version 1.5.0.

getMaxIter() Gets the value of maxIter or its default value.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName) Gets a param by its name.

getPredictionCol() Gets the value of predictionCol or its default value.

getSeed() Gets the value of seed or its default value.

getTol() Gets the value of tol or its default value.

getWeightCol() Gets the value of weightCol or its default value.

hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isSet(param) Checks whether a param is explicitly set by user.

classmethod load(path) Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value) Predict label for the given features.

New in version 3.0.0.

classmethod read() Returns an MLReader instance for this class.

save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value) Sets a parameter in the embedded param map.

setFeaturesCol(value) Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value) Sets the value of predictionCol.

New in version 3.0.0.


transform(dataset, params=None) Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset
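A minimal sketch, assuming model and df from the earlier KMeans sketches:

# Hedged sketch: transform() appends the prediction column configured on the model.
predictions = model.transform(df)
predictions.select("features", "prediction").show()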

write() Returns a GeneralMLWriter instance for this ML instance.

Attributes Documentation

distanceMeasure = Param(parent='undefined', name='distanceMeasure', doc="the distance measure. Supported options: 'euclidean' and 'cosine'.")

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

hasSummary Indicates whether a training summary exists for this model instance.

New in version 2.1.0.

initMode = Param(parent='undefined', name='initMode', doc='The initialization algorithm. This can be either "random" to choose random points as initial cluster centers, or "k-means||" to use a parallel variant of k-means++')

initSteps = Param(parent='undefined', name='initSteps', doc='The number of steps for k-means|| initialization mode. Must be > 0.')

k = Param(parent='undefined', name='k', doc='The number of clusters to create. Must be > 1.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

summary Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set. An exception is thrown if no summary exists.

New in version 2.1.0.

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


KMeansSummary

class pyspark.ml.clustering.KMeansSummary(java_obj=None) Summary of KMeans.

New in version 2.1.0.

Attributes

cluster DataFrame of predicted cluster centers for each training data point.

clusterSizes Size of (number of data points in) each cluster.

featuresCol Name for column of features in predictions.

k The number of clusters the model was trained with.

numIter Number of iterations.

predictionCol Name for column of predicted clusters in predictions.

predictions DataFrame produced by the model's transform method.

trainingCost K-means cost (sum of squared distances to the nearest centroid for all points in the training dataset).

Attributes Documentation

cluster DataFrame of predicted cluster centers for each training data point.

New in version 2.1.0.

clusterSizes Size of (number of data points in) each cluster.

New in version 2.1.0.

featuresCol Name for column of features in predictions.

New in version 2.1.0.

k The number of clusters the model was trained with.

New in version 2.1.0.

numIter Number of iterations.

New in version 2.4.0.

predictionCol Name for column of predicted clusters in predictions.

New in version 2.1.0.

predictions DataFrame produced by the model's transform method.

New in version 2.1.0.


trainingCost K-means cost (sum of squared distances to the nearest centroid for all points in the training dataset). This is equivalent to sklearn's inertia.

New in version 2.4.0.
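A hedged sketch of reading these summary attributes, assuming model is a freshly fitted KMeansModel (a model reloaded from disk has no summary):

# Hedged sketch: the summary is only available on the training side.
if model.hasSummary:
    summary = model.summary
    print(summary.clusterSizes)   # number of points assigned to each cluster
    print(summary.trainingCost)   # sum of squared distances to the nearest centroid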

GaussianMixture

class pyspark.ml.clustering.GaussianMixture(*args, **kwargs) GaussianMixture clustering. This class performs expectation maximization for multivariate Gaussian Mixture Models (GMMs). A GMM represents a composite distribution of independent Gaussian distributions with associated "mixing" weights specifying each's contribution to the composite.

Given a set of sample points, this class will maximize the log-likelihood for a mixture of k Gaussians, iterating until the log-likelihood changes by less than convergenceTol, or until it has reached the max number of iterations. While this process is generally guaranteed to converge, it is not guaranteed to find a global optimum.

New in version 2.0.0.

Notes

For high-dimensional data (with many features), this algorithm may perform poorly. This is due to high-dimensional data (a) making it difficult to cluster at all (based on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions.

Examples

>>> from pyspark.ml.linalg import Vectors

>>> data = [(Vectors.dense([-0.1, -0.05 ]),),
...         (Vectors.dense([-0.01, -0.1]),),
...         (Vectors.dense([0.9, 0.8]),),
...         (Vectors.dense([0.75, 0.935]),),
...         (Vectors.dense([-0.83, -0.68]),),
...         (Vectors.dense([-0.91, -0.76]),)]
>>> df = spark.createDataFrame(data, ["features"])
>>> gm = GaussianMixture(k=3, tol=0.0001, seed=10)
>>> gm.getMaxIter()
100
>>> gm.setMaxIter(30)
GaussianMixture...
>>> gm.getMaxIter()
30
>>> model = gm.fit(df)
>>> model.getAggregationDepth()
2
>>> model.getFeaturesCol()
'features'
>>> model.setPredictionCol("newPrediction")
GaussianMixtureModel...
>>> model.predict(df.head().features)
2
>>> model.predictProbability(df.head().features)
DenseVector([0.0, 0.0, 1.0])
>>> model.hasSummary
True
>>> summary = model.summary
>>> summary.k
3
>>> summary.clusterSizes
[2, 2, 2]
>>> summary.logLikelihood
65.02945...
>>> weights = model.weights
>>> len(weights)
3
>>> gaussians = model.gaussians
>>> len(gaussians)
3
>>> gaussians[0].mean
DenseVector([0.825, 0.8675])
>>> gaussians[0].cov
DenseMatrix(2, 2, [0.0056, -0.0051, -0.0051, 0.0046], 0)
>>> gaussians[1].mean
DenseVector([-0.87, -0.72])
>>> gaussians[1].cov
DenseMatrix(2, 2, [0.0016, 0.0016, 0.0016, 0.0016], 0)
>>> gaussians[2].mean
DenseVector([-0.055, -0.075])
>>> gaussians[2].cov
DenseMatrix(2, 2, [0.002, -0.0011, -0.0011, 0.0006], 0)
>>> model.gaussiansDF.select("mean").head()
Row(mean=DenseVector([0.825, 0.8675]))
>>> model.gaussiansDF.select("cov").head()
Row(cov=DenseMatrix(2, 2, [0.0056, -0.0051, -0.0051, 0.0046], False))
>>> transformed = model.transform(df).select("features", "newPrediction")
>>> rows = transformed.collect()
>>> rows[4].newPrediction == rows[5].newPrediction
True
>>> rows[2].newPrediction == rows[3].newPrediction
True
>>> gmm_path = temp_path + "/gmm"
>>> gm.save(gmm_path)
>>> gm2 = GaussianMixture.load(gmm_path)
>>> gm2.getK()
3
>>> model_path = temp_path + "/gmm_model"
>>> model.save(model_path)
>>> model2 = GaussianMixtureModel.load(model_path)
>>> model2.hasSummary
False
>>> model2.weights == model.weights
True
>>> model2.gaussians[0].mean == model.gaussians[0].mean
True
>>> model2.gaussians[0].cov == model.gaussians[0].cov
True
>>> model2.gaussians[1].mean == model.gaussians[1].mean
True
>>> model2.gaussians[1].cov == model.gaussians[1].cov
True
>>> model2.gaussians[2].mean == model.gaussians[2].mean
True
>>> model2.gaussians[2].cov == model.gaussians[2].cov
True
>>> model2.gaussiansDF.select("mean").head()
Row(mean=DenseVector([0.825, 0.8675]))
>>> model2.gaussiansDF.select("cov").head()
Row(cov=DenseMatrix(2, 2, [0.0056, -0.0051, -0.0051, 0.0046], False))
>>> model.transform(df).take(1) == model2.transform(df).take(1)
True
>>> gm2.setWeightCol("weight")
GaussianMixture...

Methods

clear(param) Clears a param from the param map if it has been explicitly set.

copy([extra]) Creates a copy of this instance with the same uid and some extra params.

explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

fit(dataset[, params]) Fits a model to the input dataset with optional parameters.

fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.

getAggregationDepth() Gets the value of aggregationDepth or its default value.

getFeaturesCol() Gets the value of featuresCol or its default value.

getK() Gets the value of k.

getMaxIter() Gets the value of maxIter or its default value.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.

getParam(paramName) Gets a param by its name.

getPredictionCol() Gets the value of predictionCol or its default value.

getProbabilityCol() Gets the value of probabilityCol or its default value.

getSeed() Gets the value of seed or its default value.

getTol() Gets the value of tol or its default value.

getWeightCol() Gets the value of weightCol or its default value.

hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isSet(param) Checks whether a param is explicitly set by user.

load(path) Reads an ML instance from the input path, a shortcut of read().load(path).

read() Returns an MLReader instance for this class.

save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value) Sets a parameter in the embedded param map.

setAggregationDepth(value) Sets the value of aggregationDepth.

setFeaturesCol(value) Sets the value of featuresCol.

setK(value) Sets the value of k.

setMaxIter(value) Sets the value of maxIter.

setParams(self, \*[, featuresCol, . . . ]) Sets params for GaussianMixture.

setPredictionCol(value) Sets the value of predictionCol.

setProbabilityCol(value) Sets the value of probabilityCol.

setSeed(value) Sets the value of seed.

setTol(value) Sets the value of tol.

setWeightCol(value) Sets the value of weightCol.

write() Returns an MLWriter instance for this ML instance.

Attributes

aggregationDepth
featuresCol
k
maxIter
params Returns all params ordered by name.
predictionCol
probabilityCol
seed
tol
weightCol

Methods Documentation

clear(param) Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance


explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None) Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getAggregationDepth() Gets the value of aggregationDepth or its default value.

getFeaturesCol() Gets the value of featuresCol or its default value.

getK() Gets the value of k.

New in version 2.0.0.


getMaxIter() Gets the value of maxIter or its default value.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName) Gets a param by its name.

getPredictionCol() Gets the value of predictionCol or its default value.

getProbabilityCol() Gets the value of probabilityCol or its default value.

getSeed() Gets the value of seed or its default value.

getTol() Gets the value of tol or its default value.

getWeightCol() Gets the value of weightCol or its default value.

hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isSet(param) Checks whether a param is explicitly set by user.

classmethod load(path) Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read() Returns an MLReader instance for this class.

save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value) Sets a parameter in the embedded param map.

setAggregationDepth(value) Sets the value of aggregationDepth.

New in version 3.0.0.

setFeaturesCol(value) Sets the value of featuresCol.

New in version 2.0.0.

setK(value) Sets the value of k.

New in version 2.0.0.


setMaxIter(value) Sets the value of maxIter.

New in version 2.0.0.

setParams(self, \*, featuresCol="features", predictionCol="prediction", k=2, probabilityCol="probability", tol=0.01, maxIter=100, seed=None, aggregationDepth=2, weightCol=None)

Sets params for GaussianMixture.

New in version 2.0.0.

setPredictionCol(value) Sets the value of predictionCol.

New in version 2.0.0.

setProbabilityCol(value) Sets the value of probabilityCol.

New in version 2.0.0.

setSeed(value) Sets the value of seed.

New in version 2.0.0.

setTol(value) Sets the value of tol.

New in version 2.0.0.

setWeightCol(value) Sets the value of weightCol.

New in version 3.0.0.

write() Returns an MLWriter instance for this ML instance.

Attributes Documentation

aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

k = Param(parent='undefined', name='k', doc='Number of independent Gaussians in the mixture model. Must be > 1.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


GaussianMixtureModel

class pyspark.ml.clustering.GaussianMixtureModel(java_model=None) Model fitted by GaussianMixture.

New in version 2.0.0.

Methods

clear(param) Clears a param from the param map if it has been explicitly set.

copy([extra]) Creates a copy of this instance with the same uid and some extra params.

explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

getAggregationDepth() Gets the value of aggregationDepth or its default value.

getFeaturesCol() Gets the value of featuresCol or its default value.

getK() Gets the value of k.

getMaxIter() Gets the value of maxIter or its default value.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.

getParam(paramName) Gets a param by its name.

getPredictionCol() Gets the value of predictionCol or its default value.

getProbabilityCol() Gets the value of probabilityCol or its default value.

getSeed() Gets the value of seed or its default value.

getTol() Gets the value of tol or its default value.

getWeightCol() Gets the value of weightCol or its default value.

hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isSet(param) Checks whether a param is explicitly set by user.

load(path) Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value) Predict label for the given features.

predictProbability(value) Predict probability for the given features.

read() Returns an MLReader instance for this class.

save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value) Sets a parameter in the embedded param map.

setFeaturesCol(value) Sets the value of featuresCol.

setPredictionCol(value) Sets the value of predictionCol.

setProbabilityCol(value) Sets the value of probabilityCol.

transform(dataset[, params]) Transforms the input dataset with optional parameters.

write() Returns an MLWriter instance for this ML instance.

Attributes

aggregationDepth
featuresCol
gaussians Array of MultivariateGaussian where gaussians[i] represents the Multivariate Gaussian (Normal) Distribution for Gaussian i
gaussiansDF Retrieve Gaussian distributions as a DataFrame.
hasSummary Indicates whether a training summary exists for this model instance.
k
maxIter
params Returns all params ordered by name.
predictionCol
probabilityCol
seed
summary Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set.
tol
weightCol
weights Weight for each Gaussian distribution in the mixture.

Methods Documentation

clear(param) Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getAggregationDepth() Gets the value of aggregationDepth or its default value.

getFeaturesCol() Gets the value of featuresCol or its default value.

getK() Gets the value of k.

New in version 2.0.0.

getMaxIter() Gets the value of maxIter or its default value.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName) Gets a param by its name.

getPredictionCol() Gets the value of predictionCol or its default value.

getProbabilityCol() Gets the value of probabilityCol or its default value.

getSeed() Gets the value of seed or its default value.

getTol() Gets the value of tol or its default value.

getWeightCol() Gets the value of weightCol or its default value.

hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isSet(param) Checks whether a param is explicitly set by user.

classmethod load(path) Reads an ML instance from the input path, a shortcut of read().load(path).


predict(value) Predict label for the given features.

New in version 3.0.0.

predictProbability(value) Predict probability for the given features.

New in version 3.0.0.
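A short sketch, assuming model is a fitted GaussianMixtureModel such as the one built in the class-level example:

# Hedged sketch: score a single feature vector without going through a DataFrame.
from pyspark.ml.linalg import Vectors

point = Vectors.dense([0.8, 0.85])
print(model.predict(point))             # index of the most likely Gaussian
print(model.predictProbability(point))  # DenseVector of per-component probabilities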

classmethod read() Returns an MLReader instance for this class.

save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value) Sets a parameter in the embedded param map.

setFeaturesCol(value) Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value) Sets the value of predictionCol.

New in version 3.0.0.

setProbabilityCol(value) Sets the value of probabilityCol.

New in version 3.0.0.

transform(dataset, params=None) Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write() Returns an MLWriter instance for this ML instance.

Attributes Documentation

aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

gaussians Array of MultivariateGaussian where gaussians[i] represents the Multivariate Gaussian (Normal) Distribution for Gaussian i.

New in version 3.0.0.


gaussiansDF Retrieve Gaussian distributions as a DataFrame. Each row represents a Gaussian Distribution. The DataFrame has two columns: mean (Vector) and cov (Matrix).

New in version 2.0.0.
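A brief sketch, assuming model is a fitted GaussianMixtureModel:

# Hedged sketch: the same component parameters, as a DataFrame and as Python objects.
model.gaussiansDF.show(truncate=False)   # columns: mean (Vector), cov (Matrix)
for g in model.gaussians:
    print(g.mean, g.cov)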

hasSummary Indicates whether a training summary exists for this model instance.

New in version 2.1.0.

k = Param(parent='undefined', name='k', doc='Number of independent Gaussians in the mixture model. Must be > 1.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

summary Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set. An exception is thrown if no summary exists.

New in version 2.1.0.

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

weights Weight for each Gaussian distribution in the mixture. This is a multinomial probability distribution over the k Gaussians, where weights[i] is the weight for Gaussian i, and weights sum to 1.

New in version 2.0.0.

GaussianMixtureSummary

class pyspark.ml.clustering.GaussianMixtureSummary(java_obj=None) Gaussian mixture clustering results for a given model.

New in version 2.1.0.

Attributes

cluster DataFrame of predicted cluster centers for each training data point.

clusterSizes Size of (number of data points in) each cluster.

featuresCol Name for column of features in predictions.

k The number of clusters the model was trained with.

logLikelihood Total log-likelihood for this model on the given data.

numIter Number of iterations.

predictionCol Name for column of predicted clusters in predictions.

predictions DataFrame produced by the model's transform method.

probability DataFrame of probabilities of each cluster for each training data point.

probabilityCol Name for column of predicted probability of each cluster in predictions.

Attributes Documentation

cluster DataFrame of predicted cluster centers for each training data point.

New in version 2.1.0.

clusterSizes Size of (number of data points in) each cluster.

New in version 2.1.0.

featuresCol Name for column of features in predictions.

New in version 2.1.0.

k The number of clusters the model was trained with.

New in version 2.1.0.

logLikelihood Total log-likelihood for this model on the given data.

New in version 2.2.0.

numIter Number of iterations.

New in version 2.4.0.

predictionCol Name for column of predicted clusters in predictions.

New in version 2.1.0.

predictions DataFrame produced by the model's transform method.

New in version 2.1.0.

probability DataFrame of probabilities of each cluster for each training data point.

New in version 2.1.0.

probabilityCol Name for column of predicted probability of each cluster in predictions.

New in version 2.1.0.


LDA

class pyspark.ml.clustering.LDA(*args, **kwargs) Latent Dirichlet Allocation (LDA), a topic model designed for text documents.

Terminology:

• “term” = “word”: an element of the vocabulary

• “token”: instance of a term appearing in a document

• “topic”: multinomial distribution over terms representing some concept

• “document”: one piece of text, corresponding to one row in the input data

Original LDA paper (journal version): Blei, Ng, and Jordan. “Latent Dirichlet Allocation.” JMLR, 2003.

Input data (featuresCol): LDA is given a collection of documents as input data, via the featuresCol parameter. Each document is specified as a Vector of length vocabSize, where each entry is the count for the corresponding term (word) in the document. Feature transformers such as pyspark.ml.feature.Tokenizer and pyspark.ml.feature.CountVectorizer can be useful for converting text to word count vectors.

New in version 2.0.0.

Examples

>>> from pyspark.ml.linalg import Vectors, SparseVector
>>> from pyspark.ml.clustering import LDA
>>> df = spark.createDataFrame([[1, Vectors.dense([0.0, 1.0])],
...     [2, SparseVector(2, {0: 1.0})],], ["id", "features"])
>>> lda = LDA(k=2, seed=1, optimizer="em")
>>> lda.setMaxIter(10)
LDA...
>>> lda.getMaxIter()
10
>>> lda.clear(lda.maxIter)
>>> model = lda.fit(df)
>>> model.setSeed(1)
DistributedLDAModel...
>>> model.getTopicDistributionCol()
'topicDistribution'
>>> model.isDistributed()
True
>>> localModel = model.toLocal()
>>> localModel.isDistributed()
False
>>> model.vocabSize()
2
>>> model.describeTopics().show()
+-----+-----------+--------------------+
|topic|termIndices|         termWeights|
+-----+-----------+--------------------+
|    0|     [1, 0]|[0.50401530077160...|
|    1|     [0, 1]|[0.50401530077160...|
+-----+-----------+--------------------+
...
>>> model.topicsMatrix()
DenseMatrix(2, 2, [0.496, 0.504, 0.504, 0.496], 0)
>>> lda_path = temp_path + "/lda"
>>> lda.save(lda_path)
>>> sameLDA = LDA.load(lda_path)
>>> distributed_model_path = temp_path + "/lda_distributed_model"
>>> model.save(distributed_model_path)
>>> sameModel = DistributedLDAModel.load(distributed_model_path)
>>> local_model_path = temp_path + "/lda_local_model"
>>> localModel.save(local_model_path)
>>> sameLocalModel = LocalLDAModel.load(local_model_path)
>>> model.transform(df).take(1) == sameLocalModel.transform(df).take(1)
True
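The doctest above starts from pre-built count vectors. As a complement, the hedged sketch below shows one way to produce the featuresCol input from raw text with Tokenizer and CountVectorizer, as suggested in the class description; the toy corpus and column names are illustrative assumptions, and it presumes an active SparkSession named spark.

# Hedged sketch: raw text -> word-count vectors -> LDA, wired together with a Pipeline.
from pyspark.ml import Pipeline
from pyspark.ml.clustering import LDA
from pyspark.ml.feature import CountVectorizer, Tokenizer

docs = spark.createDataFrame(
    [(0, "spark runs on the jvm"), (1, "python talks to spark")],
    ["id", "text"])

pipeline = Pipeline(stages=[
    Tokenizer(inputCol="text", outputCol="words"),
    CountVectorizer(inputCol="words", outputCol="features"),
    LDA(k=2, maxIter=10, seed=1),
])
lda_model = pipeline.fit(docs).stages[-1]   # the fitted LDAModel is the last stage
lda_model.describeTopics().show(truncate=False)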

Methods

clear(param) Clears a param from the param map if it has been explicitly set.

copy([extra]) Creates a copy of this instance with the same uid and some extra params.

explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

fit(dataset[, params]) Fits a model to the input dataset with optional parameters.

fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.

getCheckpointInterval() Gets the value of checkpointInterval or its default value.

getDocConcentration() Gets the value of docConcentration or its default value.

getFeaturesCol() Gets the value of featuresCol or its default value.

getK() Gets the value of k or its default value.

getKeepLastCheckpoint() Gets the value of keepLastCheckpoint or its default value.

getLearningDecay() Gets the value of learningDecay or its default value.

getLearningOffset() Gets the value of learningOffset or its default value.

getMaxIter() Gets the value of maxIter or its default value.

getOptimizeDocConcentration() Gets the value of optimizeDocConcentration or its default value.

getOptimizer() Gets the value of optimizer or its default value.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.

getParam(paramName) Gets a param by its name.

getSeed() Gets the value of seed or its default value.

getSubsamplingRate() Gets the value of subsamplingRate or its default value.

getTopicConcentration() Gets the value of topicConcentration or its default value.

getTopicDistributionCol() Gets the value of topicDistributionCol or its default value.

hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isSet(param) Checks whether a param is explicitly set by user.

load(path) Reads an ML instance from the input path, a shortcut of read().load(path).

read() Returns an MLReader instance for this class.

save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value) Sets a parameter in the embedded param map.

setCheckpointInterval(value) Sets the value of checkpointInterval.

setDocConcentration(value) Sets the value of docConcentration.

setFeaturesCol(value) Sets the value of featuresCol.

setK(value) Sets the value of k.

setKeepLastCheckpoint(value) Sets the value of keepLastCheckpoint.

setLearningDecay(value) Sets the value of learningDecay.

setLearningOffset(value) Sets the value of learningOffset.

setMaxIter(value) Sets the value of maxIter.

setOptimizeDocConcentration(value) Sets the value of optimizeDocConcentration.

setOptimizer(value) Sets the value of optimizer.

setParams(self, \*[, featuresCol, maxIter, . . . ]) Sets params for LDA.

setSeed(value) Sets the value of seed.

setSubsamplingRate(value) Sets the value of subsamplingRate.

setTopicConcentration(value) Sets the value of topicConcentration.

setTopicDistributionCol(value) Sets the value of topicDistributionCol.

write() Returns an MLWriter instance for this ML instance.

Attributes

checkpointInterval
docConcentration
featuresCol
k
keepLastCheckpoint
learningDecay
learningOffset
maxIter
optimizeDocConcentration
optimizer
params Returns all params ordered by name.
seed
subsamplingRate
topicConcentration
topicDistributionCol

Methods Documentation

clear(param) Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None) Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns


Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getCheckpointInterval() Gets the value of checkpointInterval or its default value.

getDocConcentration() Gets the value of docConcentration or its default value.

New in version 2.0.0.

getFeaturesCol() Gets the value of featuresCol or its default value.

getK() Gets the value of k or its default value.

New in version 2.0.0.

getKeepLastCheckpoint() Gets the value of keepLastCheckpoint or its default value.

New in version 2.0.0.

getLearningDecay() Gets the value of learningDecay or its default value.

New in version 2.0.0.

getLearningOffset() Gets the value of learningOffset or its default value.

New in version 2.0.0.

getMaxIter() Gets the value of maxIter or its default value.

getOptimizeDocConcentration() Gets the value of optimizeDocConcentration or its default value.

New in version 2.0.0.

getOptimizer() Gets the value of optimizer or its default value.

New in version 2.0.0.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.


getParam(paramName) Gets a param by its name.

getSeed() Gets the value of seed or its default value.

getSubsamplingRate() Gets the value of subsamplingRate or its default value.

New in version 2.0.0.

getTopicConcentration() Gets the value of topicConcentration or its default value.

New in version 2.0.0.

getTopicDistributionCol() Gets the value of topicDistributionCol or its default value.

New in version 2.0.0.

hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isSet(param) Checks whether a param is explicitly set by user.

classmethod load(path) Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read() Returns an MLReader instance for this class.

save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value) Sets a parameter in the embedded param map.

setCheckpointInterval(value) Sets the value of checkpointInterval.

New in version 2.0.0.

setDocConcentration(value) Sets the value of docConcentration.

Examples

>>> algo = LDA().setDocConcentration([0.1, 0.2])
>>> algo.getDocConcentration()
[0.1..., 0.2...]

New in version 2.0.0.

setFeaturesCol(value) Sets the value of featuresCol.

New in version 2.0.0.

setK(value) Sets the value of k.

>>> algo = LDA().setK(10)
>>> algo.getK()
10

New in version 2.0.0.

setKeepLastCheckpoint(value) Sets the value of keepLastCheckpoint.

Examples

>>> algo = LDA().setKeepLastCheckpoint(False)
>>> algo.getKeepLastCheckpoint()
False

New in version 2.0.0.

setLearningDecay(value) Sets the value of learningDecay.

Examples

>>> algo = LDA().setLearningDecay(0.1)
>>> algo.getLearningDecay()
0.1...

New in version 2.0.0.

setLearningOffset(value) Sets the value of learningOffset.

Examples

>>> algo = LDA().setLearningOffset(100)
>>> algo.getLearningOffset()
100.0

New in version 2.0.0.

setMaxIter(value) Sets the value of maxIter.

New in version 2.0.0.

setOptimizeDocConcentration(value) Sets the value of optimizeDocConcentration.

Examples

>>> algo = LDA().setOptimizeDocConcentration(True)
>>> algo.getOptimizeDocConcentration()
True

New in version 2.0.0.

setOptimizer(value) Sets the value of optimizer. Currently only 'em' and 'online' are supported.

Examples

>>> algo = LDA().setOptimizer("em")
>>> algo.getOptimizer()
'em'

New in version 2.0.0.

setParams(self, \*, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10, k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51, subsamplingRate=0.05, optimizeDocConcentration=True, docConcentration=None, topicConcentration=None, topicDistributionCol="topicDistribution", keepLastCheckpoint=True)

Sets params for LDA.

New in version 2.0.0.

setSeed(value) Sets the value of seed.

New in version 2.0.0.

setSubsamplingRate(value) Sets the value of subsamplingRate.

Examples

>>> algo = LDA().setSubsamplingRate(0.1)
>>> algo.getSubsamplingRate()
0.1...

New in version 2.0.0.

setTopicConcentration(value) Sets the value of topicConcentration.

Examples

>>> algo = LDA().setTopicConcentration(0.5)
>>> algo.getTopicConcentration()
0.5...

New in version 2.0.0.

setTopicDistributionCol(value) Sets the value of topicDistributionCol.

Examples

>>> algo = LDA().setTopicDistributionCol("topicDistributionCol")
>>> algo.getTopicDistributionCol()
'topicDistributionCol'

New in version 2.0.0.

write() Returns an MLWriter instance for this ML instance.

Attributes Documentation

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

docConcentration = Param(parent='undefined', name='docConcentration', doc='Concentration parameter (commonly named "alpha") for the prior placed on documents\' distributions over topics ("theta").')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

k = Param(parent='undefined', name='k', doc='The number of topics (clusters) to infer. Must be > 1.')

keepLastCheckpoint = Param(parent='undefined', name='keepLastCheckpoint', doc='(For EM optimizer) If using checkpointing, this indicates whether to keep the last checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with care.')

learningDecay = Param(parent='undefined', name='learningDecay', doc='Learning rate, set as an exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic convergence.')

learningOffset = Param(parent='undefined', name='learningOffset', doc='A (positive) learning parameter that downweights early iterations. Larger values make early iterations count less')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

optimizeDocConcentration = Param(parent='undefined', name='optimizeDocConcentration', doc='Indicates whether the docConcentration (Dirichlet parameter for document-topic distribution) will be optimized during training.')

optimizer = Param(parent='undefined', name='optimizer', doc='Optimizer or inference algorithm used to estimate the LDA model. Supported: online, em')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.


seed = Param(parent='undefined', name='seed', doc='random seed.')

subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].')

topicConcentration = Param(parent='undefined', name='topicConcentration', doc='Concentration parameter (commonly named "beta" or "eta") for the prior placed on topic\' distributions over terms.')

topicDistributionCol = Param(parent='undefined', name='topicDistributionCol', doc='Output column with estimates of the topic mixture distribution for each document (often called "theta" in the literature). Returns a vector of zeros for an empty document.')

LDAModel

class pyspark.ml.clustering.LDAModel(java_model=None) Latent Dirichlet Allocation (LDA) model. This abstraction permits for different underlying representations, including local and distributed data structures.

New in version 2.0.0.

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
describeTopics([maxTermsPerTopic]) Return the topics described by their top-weighted terms.
estimatedDocConcentration() Value for LDA.docConcentration estimated from data.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCheckpointInterval() Gets the value of checkpointInterval or its default value.
getDocConcentration() Gets the value of docConcentration or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getK() Gets the value of k or its default value.
getKeepLastCheckpoint() Gets the value of keepLastCheckpoint or its default value.
getLearningDecay() Gets the value of learningDecay or its default value.
getLearningOffset() Gets the value of learningOffset or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOptimizeDocConcentration() Gets the value of optimizeDocConcentration or its default value.
getOptimizer() Gets the value of optimizer or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getSeed() Gets the value of seed or its default value.
getSubsamplingRate() Gets the value of subsamplingRate or its default value.
getTopicConcentration() Gets the value of topicConcentration or its default value.
getTopicDistributionCol() Gets the value of topicDistributionCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isDistributed() Indicates whether this instance is of type DistributedLDAModel.
isSet(param) Checks whether a param is explicitly set by user.
logLikelihood(dataset) Calculates a lower bound on the log likelihood of the entire corpus.
logPerplexity(dataset) Calculate an upper bound on perplexity.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setSeed(value) Sets the value of seed.
setTopicDistributionCol(value) Sets the value of topicDistributionCol.
topicsMatrix() Inferred topics, where each topic is represented by a distribution over terms.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
vocabSize() Vocabulary size (number of terms or words in the vocabulary).

Attributes

checkpointInterval
docConcentration
featuresCol
k
keepLastCheckpoint
learningDecay
learningOffset
maxIter
optimizeDocConcentration
optimizer
params Returns all params ordered by name.
seed
subsamplingRate
topicConcentration
topicDistributionCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

describeTopics(maxTermsPerTopic=10)
Return the topics described by their top-weighted terms.

New in version 2.0.0.

estimatedDocConcentration()
Value for LDA.docConcentration estimated from data. If Online LDA was used and LDA.optimizeDocConcentration was set to false, then this returns the fixed (given) value for the LDA.docConcentration parameter.

New in version 2.0.0.

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
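
To make the ordering concrete, the sketch below (using the LDA estimator; any Params subclass behaves the same) shows a default being overridden by a user-supplied value, which in turn is overridden by an explicit extra dict. Values are illustrative.

from pyspark.ml.clustering import LDA

lda = LDA(k=5)                               # user-supplied value for k
param_map = lda.extractParamMap()            # defaults merged with user-supplied values
print(param_map[lda.k])                      # 5: user-supplied value wins over the default

overridden = lda.extractParamMap({lda.k: 10})
print(overridden[lda.k])                     # 10: the extra dict wins over both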

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getDocConcentration()
Gets the value of docConcentration or its default value.


New in version 2.0.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getK()
Gets the value of k or its default value.

New in version 2.0.0.

getKeepLastCheckpoint()
Gets the value of keepLastCheckpoint or its default value.

New in version 2.0.0.

getLearningDecay()
Gets the value of learningDecay or its default value.

New in version 2.0.0.

getLearningOffset()
Gets the value of learningOffset or its default value.

New in version 2.0.0.

getMaxIter()
Gets the value of maxIter or its default value.

getOptimizeDocConcentration()
Gets the value of optimizeDocConcentration or its default value.

New in version 2.0.0.

getOptimizer()
Gets the value of optimizer or its default value.

New in version 2.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getSeed()
Gets the value of seed or its default value.

getSubsamplingRate()
Gets the value of subsamplingRate or its default value.

New in version 2.0.0.

getTopicConcentration()
Gets the value of topicConcentration or its default value.

New in version 2.0.0.

getTopicDistributionCol()
Gets the value of topicDistributionCol or its default value.

New in version 2.0.0.

hasDefault(param)
Checks whether a param has a default value.


hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isDistributed()
Indicates whether this instance is of type DistributedLDAModel.

New in version 2.0.0.

isSet(param)
Checks whether a param is explicitly set by user.

logLikelihood(dataset)
Calculates a lower bound on the log likelihood of the entire corpus. See Equation (16) in the Online LDA paper (Hoffman et al., 2010).

Warning: If this model is an instance of DistributedLDAModel (produced when optimizer is set to "em"), this involves collecting a large topicsMatrix() to the driver. This implementation may be changed in the future.

New in version 2.0.0.

logPerplexity(dataset)
Calculate an upper bound on perplexity. (Lower is better.) See Equation (16) in the Online LDA paper (Hoffman et al., 2010).

Warning: If this model is an instance of DistributedLDAModel (produced when optimizer is set to "em"), this involves collecting a large topicsMatrix() to the driver. This implementation may be changed in the future.

New in version 2.0.0.
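
Both bounds are evaluated against a DataFrame of feature vectors. A minimal sketch, reusing the illustrative model and df names from the fitting example above:

ll = model.logLikelihood(df)    # lower bound on the corpus log likelihood (higher is better)
lp = model.logPerplexity(df)    # upper bound on perplexity (lower is better)
print(ll, lp)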

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setSeed(value)
Sets the value of seed.

New in version 3.0.0.

setTopicDistributionCol(value)
Sets the value of topicDistributionCol.

New in version 3.0.0.

topicsMatrix()
Inferred topics, where each topic is represented by a distribution over terms. This is a matrix of size vocabSize x k, where each column is a topic. No guarantees are given about the ordering of the topics.


Warning: If this model is actually a DistributedLDAModel instance produced by the Expectation-Maximization ("em") optimizer, then this method could involve collecting a large amount of data to the driver (on the order of vocabSize x k).

New in version 2.0.0.
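
A small sketch of inspecting the result, assuming model is a fitted LDAModel whose vocabulary and topic count are small enough to collect safely:

topics = model.topicsMatrix()            # matrix with vocabSize rows and k columns
print(topics.numRows, topics.numCols)
print(topics.toArray()[:, 0])            # term distribution of the first topic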

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset
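
For an LDAModel the transformed DataFrame gains the column configured by topicDistributionCol; a hedged sketch, again reusing the illustrative model and df:

transformed = model.transform(df)
transformed.select("topicDistribution").show(truncate=False)

# Overriding an embedded param only for this call:
renamed = model.transform(df, {model.topicDistributionCol: "theta"})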

vocabSize()
Vocabulary size (number of terms or words in the vocabulary).

New in version 2.0.0.

Attributes Documentation

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

docConcentration = Param(parent='undefined', name='docConcentration', doc='Concentration parameter (commonly named "alpha") for the prior placed on documents\' distributions over topics ("theta").')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

k = Param(parent='undefined', name='k', doc='The number of topics (clusters) to infer. Must be > 1.')

keepLastCheckpoint = Param(parent='undefined', name='keepLastCheckpoint', doc='(For EM optimizer) If using checkpointing, this indicates whether to keep the last checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with care.')

learningDecay = Param(parent='undefined', name='learningDecay', doc='Learning rate, set as an exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic convergence.')

learningOffset = Param(parent='undefined', name='learningOffset', doc='A (positive) learning parameter that downweights early iterations. Larger values make early iterations count less')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

optimizeDocConcentration = Param(parent='undefined', name='optimizeDocConcentration', doc='Indicates whether the docConcentration (Dirichlet parameter for document-topic distribution) will be optimized during training.')

optimizer = Param(parent='undefined', name='optimizer', doc='Optimizer or inference algorithm used to estimate the LDA model. Supported: online, em')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

seed = Param(parent='undefined', name='seed', doc='random seed.')

subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].')

topicConcentration = Param(parent='undefined', name='topicConcentration', doc='Concentration parameter (commonly named "beta" or "eta") for the prior placed on topic\' distributions over terms.')

topicDistributionCol = Param(parent='undefined', name='topicDistributionCol', doc='Output column with estimates of the topic mixture distribution for each document (often called "theta" in the literature). Returns a vector of zeros for an empty document.')


LocalLDAModel

class pyspark.ml.clustering.LocalLDAModel(java_model=None)
Local (non-distributed) model fitted by LDA. This model stores the inferred topics only; it does not store info about the training dataset.

New in version 2.0.0.
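
Because only the inferred topics are stored, a LocalLDAModel is cheap to persist and reload; a minimal sketch with an assumed, illustrative path:

# `model` is assumed to be a LocalLDAModel produced by the default online optimizer.
model.save("/tmp/local-lda-model")                     # shortcut for write().save(path)

from pyspark.ml.clustering import LocalLDAModel
restored = LocalLDAModel.load("/tmp/local-lda-model")
print(restored.vocabSize(), restored.isDistributed())  # isDistributed() is False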

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
describeTopics([maxTermsPerTopic]) Return the topics described by their top-weighted terms.
estimatedDocConcentration() Value for LDA.docConcentration estimated from data.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCheckpointInterval() Gets the value of checkpointInterval or its default value.
getDocConcentration() Gets the value of docConcentration or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getK() Gets the value of k or its default value.
getKeepLastCheckpoint() Gets the value of keepLastCheckpoint or its default value.
getLearningDecay() Gets the value of learningDecay or its default value.
getLearningOffset() Gets the value of learningOffset or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOptimizeDocConcentration() Gets the value of optimizeDocConcentration or its default value.
getOptimizer() Gets the value of optimizer or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getSeed() Gets the value of seed or its default value.
getSubsamplingRate() Gets the value of subsamplingRate or its default value.
getTopicConcentration() Gets the value of topicConcentration or its default value.
getTopicDistributionCol() Gets the value of topicDistributionCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isDistributed() Indicates whether this instance is of type DistributedLDAModel.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
logLikelihood(dataset) Calculates a lower bound on the log likelihood of the entire corpus.
logPerplexity(dataset) Calculate an upper bound on perplexity.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setSeed(value) Sets the value of seed.
setTopicDistributionCol(value) Sets the value of topicDistributionCol.
topicsMatrix() Inferred topics, where each topic is represented by a distribution over terms.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
vocabSize() Vocabulary size (number of terms or words in the vocabulary).
write() Returns an MLWriter instance for this ML instance.

Attributes

checkpointInterval
docConcentration
featuresCol
k
keepLastCheckpoint
learningDecay
learningOffset
maxIter
optimizeDocConcentration
optimizer
params Returns all params ordered by name.
seed
subsamplingRate
topicConcentration
topicDistributionCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

describeTopics(maxTermsPerTopic=10)
Return the topics described by their top-weighted terms.

New in version 2.0.0.

estimatedDocConcentration()
Value for LDA.docConcentration estimated from data. If Online LDA was used and LDA.optimizeDocConcentration was set to false, then this returns the fixed (given) value for the LDA.docConcentration parameter.

New in version 2.0.0.

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getDocConcentration()
Gets the value of docConcentration or its default value.

New in version 2.0.0.


getFeaturesCol()
Gets the value of featuresCol or its default value.

getK()
Gets the value of k or its default value.

New in version 2.0.0.

getKeepLastCheckpoint()
Gets the value of keepLastCheckpoint or its default value.

New in version 2.0.0.

getLearningDecay()
Gets the value of learningDecay or its default value.

New in version 2.0.0.

getLearningOffset()
Gets the value of learningOffset or its default value.

New in version 2.0.0.

getMaxIter()
Gets the value of maxIter or its default value.

getOptimizeDocConcentration()
Gets the value of optimizeDocConcentration or its default value.

New in version 2.0.0.

getOptimizer()
Gets the value of optimizer or its default value.

New in version 2.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getSeed()
Gets the value of seed or its default value.

getSubsamplingRate()
Gets the value of subsamplingRate or its default value.

New in version 2.0.0.

getTopicConcentration()
Gets the value of topicConcentration or its default value.

New in version 2.0.0.

getTopicDistributionCol()
Gets the value of topicDistributionCol or its default value.

New in version 2.0.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.


isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isDistributed()
Indicates whether this instance is of type DistributedLDAModel.

New in version 2.0.0.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

logLikelihood(dataset)
Calculates a lower bound on the log likelihood of the entire corpus. See Equation (16) in the Online LDA paper (Hoffman et al., 2010).

Warning: If this model is an instance of DistributedLDAModel (produced when optimizer is set to "em"), this involves collecting a large topicsMatrix() to the driver. This implementation may be changed in the future.

New in version 2.0.0.

logPerplexity(dataset)
Calculate an upper bound on perplexity. (Lower is better.) See Equation (16) in the Online LDA paper (Hoffman et al., 2010).

Warning: If this model is an instance of DistributedLDAModel (produced when optimizer is set to "em"), this involves collecting a large topicsMatrix() to the driver. This implementation may be changed in the future.

New in version 2.0.0.

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setSeed(value)
Sets the value of seed.

New in version 3.0.0.

setTopicDistributionCol(value)
Sets the value of topicDistributionCol.

New in version 3.0.0.


topicsMatrix()
Inferred topics, where each topic is represented by a distribution over terms. This is a matrix of size vocabSize x k, where each column is a topic. No guarantees are given about the ordering of the topics.

Warning: If this model is actually a DistributedLDAModel instance produced by the Expectation-Maximization ("em") optimizer, then this method could involve collecting a large amount of data to the driver (on the order of vocabSize x k).

New in version 2.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

vocabSize()
Vocabulary size (number of terms or words in the vocabulary).

New in version 2.0.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

docConcentration = Param(parent='undefined', name='docConcentration', doc='Concentration parameter (commonly named "alpha") for the prior placed on documents\' distributions over topics ("theta").')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

k = Param(parent='undefined', name='k', doc='The number of topics (clusters) to infer. Must be > 1.')

keepLastCheckpoint = Param(parent='undefined', name='keepLastCheckpoint', doc='(For EM optimizer) If using checkpointing, this indicates whether to keep the last checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with care.')

learningDecay = Param(parent='undefined', name='learningDecay', doc='Learning rate, set as an exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic convergence.')

learningOffset = Param(parent='undefined', name='learningOffset', doc='A (positive) learning parameter that downweights early iterations. Larger values make early iterations count less')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

optimizeDocConcentration = Param(parent='undefined', name='optimizeDocConcentration', doc='Indicates whether the docConcentration (Dirichlet parameter for document-topic distribution) will be optimized during training.')

optimizer = Param(parent='undefined', name='optimizer', doc='Optimizer or inference algorithm used to estimate the LDA model. Supported: online, em')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

seed = Param(parent='undefined', name='seed', doc='random seed.')

subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].')


topicConcentration = Param(parent='undefined', name='topicConcentration', doc='Concentration parameter (commonly named "beta" or "eta") for the prior placed on topic\' distributions over terms.')

topicDistributionCol = Param(parent='undefined', name='topicDistributionCol', doc='Output column with estimates of the topic mixture distribution for each document (often called "theta" in the literature). Returns a vector of zeros for an empty document.')

DistributedLDAModel

class pyspark.ml.clustering.DistributedLDAModel(java_model=None)
Distributed model fitted by LDA. This type of model is currently only produced by Expectation-Maximization (EM).

This model stores the inferred topics, the full training dataset, and the topic distribution for each training document.

New in version 2.0.0.
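
A DistributedLDAModel is only produced when the estimator is configured with the EM optimizer. The sketch below is illustrative (df is an assumed DataFrame with a "features" column of term-count vectors) and converts the result to a local model once the extra training state is no longer needed:

from pyspark.ml.clustering import LDA

lda = LDA(k=2, maxIter=10, optimizer="em")
dist_model = lda.fit(df)
print(dist_model.isDistributed())            # True

print(dist_model.logPrior())                 # prior term, available only for EM models
print(dist_model.trainingLogLikelihood())

local_model = dist_model.toLocal()           # drops the per-document training state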

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
describeTopics([maxTermsPerTopic]) Return the topics described by their top-weighted terms.
estimatedDocConcentration() Value for LDA.docConcentration estimated from data.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getCheckpointFiles() If using checkpointing and LDA.keepLastCheckpoint is set to true, then there may be saved checkpoint files.
getCheckpointInterval() Gets the value of checkpointInterval or its default value.
getDocConcentration() Gets the value of docConcentration or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getK() Gets the value of k or its default value.
getKeepLastCheckpoint() Gets the value of keepLastCheckpoint or its default value.
getLearningDecay() Gets the value of learningDecay or its default value.
getLearningOffset() Gets the value of learningOffset or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOptimizeDocConcentration() Gets the value of optimizeDocConcentration or its default value.
getOptimizer() Gets the value of optimizer or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getSeed() Gets the value of seed or its default value.
getSubsamplingRate() Gets the value of subsamplingRate or its default value.
getTopicConcentration() Gets the value of topicConcentration or its default value.
getTopicDistributionCol() Gets the value of topicDistributionCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isDistributed() Indicates whether this instance is of type DistributedLDAModel.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
logLikelihood(dataset) Calculates a lower bound on the log likelihood of the entire corpus.
logPerplexity(dataset) Calculate an upper bound on perplexity.
logPrior() Log probability of the current parameter estimate: log P(topics, topic distributions for docs | alpha, eta)
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setSeed(value) Sets the value of seed.
setTopicDistributionCol(value) Sets the value of topicDistributionCol.
toLocal() Convert this distributed model to a local representation.
topicsMatrix() Inferred topics, where each topic is represented by a distribution over terms.
trainingLogLikelihood() Log likelihood of the observed tokens in the training set, given the current parameter estimates: log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters)
transform(dataset[, params]) Transforms the input dataset with optional parameters.
vocabSize() Vocabulary size (number of terms or words in the vocabulary).
write() Returns an MLWriter instance for this ML instance.


Attributes

checkpointInterval
docConcentration
featuresCol
k
keepLastCheckpoint
learningDecay
learningOffset
maxIter
optimizeDocConcentration
optimizer
params Returns all params ordered by name.
seed
subsamplingRate
topicConcentration
topicDistributionCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

describeTopics(maxTermsPerTopic=10)
Return the topics described by their top-weighted terms.

New in version 2.0.0.

estimatedDocConcentration()
Value for LDA.docConcentration estimated from data. If Online LDA was used and LDA.optimizeDocConcentration was set to false, then this returns the fixed (given) value for the LDA.docConcentration parameter.

New in version 2.0.0.

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getCheckpointFiles()
If using checkpointing and LDA.keepLastCheckpoint is set to true, then there may be saved checkpoint files. This method is provided so that users can manage those files.

New in version 2.0.0.

Returns

list List of checkpoint files from training

Notes

Removing the checkpoints can cause failures if a partition is lost and is needed by certain DistributedLDAModel methods. Reference counting will clean up the checkpoints when this model and derivative data go out of scope.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getDocConcentration()
Gets the value of docConcentration or its default value.

New in version 2.0.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getK()
Gets the value of k or its default value.

New in version 2.0.0.

getKeepLastCheckpoint()
Gets the value of keepLastCheckpoint or its default value.

New in version 2.0.0.

getLearningDecay()
Gets the value of learningDecay or its default value.

New in version 2.0.0.

getLearningOffset()
Gets the value of learningOffset or its default value.

New in version 2.0.0.

getMaxIter()
Gets the value of maxIter or its default value.

getOptimizeDocConcentration()
Gets the value of optimizeDocConcentration or its default value.

New in version 2.0.0.


getOptimizer()
Gets the value of optimizer or its default value.

New in version 2.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getSeed()
Gets the value of seed or its default value.

getSubsamplingRate()
Gets the value of subsamplingRate or its default value.

New in version 2.0.0.

getTopicConcentration()
Gets the value of topicConcentration or its default value.

New in version 2.0.0.

getTopicDistributionCol()
Gets the value of topicDistributionCol or its default value.

New in version 2.0.0.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isDistributed()
Indicates whether this instance is of type DistributedLDAModel.

New in version 2.0.0.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

logLikelihood(dataset)
Calculates a lower bound on the log likelihood of the entire corpus. See Equation (16) in the Online LDA paper (Hoffman et al., 2010).

Warning: If this model is an instance of DistributedLDAModel (produced when optimizer is set to "em"), this involves collecting a large topicsMatrix() to the driver. This implementation may be changed in the future.

New in version 2.0.0.


logPerplexity(dataset)
Calculate an upper bound on perplexity. (Lower is better.) See Equation (16) in the Online LDA paper (Hoffman et al., 2010).

Warning: If this model is an instance of DistributedLDAModel (produced when optimizer is set to "em"), this involves collecting a large topicsMatrix() to the driver. This implementation may be changed in the future.

New in version 2.0.0.

logPrior()
Log probability of the current parameter estimate: log P(topics, topic distributions for docs | alpha, eta)

New in version 2.0.0.

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setSeed(value)
Sets the value of seed.

New in version 3.0.0.

setTopicDistributionCol(value)
Sets the value of topicDistributionCol.

New in version 3.0.0.

toLocal()
Convert this distributed model to a local representation. This discards info about the training dataset.

Warning: This involves collecting a large topicsMatrix() to the driver.

New in version 2.0.0.

topicsMatrix()
Inferred topics, where each topic is represented by a distribution over terms. This is a matrix of size vocabSize x k, where each column is a topic. No guarantees are given about the ordering of the topics.

Warning: If this model is actually a DistributedLDAModel instance produced by the Expectation-Maximization ("em") optimizer, then this method could involve collecting a large amount of data to the driver (on the order of vocabSize x k).

New in version 2.0.0.


trainingLogLikelihood()
Log likelihood of the observed tokens in the training set, given the current parameter estimates: log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters)

Notes

• This excludes the prior; for that, use logPrior().

• Even with logPrior(), this is NOT the same as the data log likelihood given the hyperparameters.

• This is computed from the topic distributions computed during training. If you call logLikelihood() on the same training dataset, the topic distributions will be computed again, possibly giving different results.

New in version 2.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

vocabSize()
Vocabulary size (number of terms or words in the vocabulary).

New in version 2.0.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

docConcentration = Param(parent='undefined', name='docConcentration', doc='Concentration parameter (commonly named "alpha") for the prior placed on documents\' distributions over topics ("theta").')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

k = Param(parent='undefined', name='k', doc='The number of topics (clusters) to infer. Must be > 1.')

keepLastCheckpoint = Param(parent='undefined', name='keepLastCheckpoint', doc='(For EM optimizer) If using checkpointing, this indicates whether to keep the last checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with care.')

learningDecay = Param(parent='undefined', name='learningDecay', doc='Learning rate, set as an exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic convergence.')

learningOffset = Param(parent='undefined', name='learningOffset', doc='A (positive) learning parameter that downweights early iterations. Larger values make early iterations count less')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

optimizeDocConcentration = Param(parent='undefined', name='optimizeDocConcentration', doc='Indicates whether the docConcentration (Dirichlet parameter for document-topic distribution) will be optimized during training.')

optimizer = Param(parent='undefined', name='optimizer', doc='Optimizer or inference algorithm used to estimate the LDA model. Supported: online, em')


params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

seed = Param(parent='undefined', name='seed', doc='random seed.')

subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].')

topicConcentration = Param(parent='undefined', name='topicConcentration', doc='Concentration parameter (commonly named "beta" or "eta") for the prior placed on topic\' distributions over terms.')

topicDistributionCol = Param(parent='undefined', name='topicDistributionCol', doc='Output column with estimates of the topic mixture distribution for each document (often called "theta" in the literature). Returns a vector of zeros for an empty document.')

PowerIterationClustering

class pyspark.ml.clustering.PowerIterationClustering(*args, **kwargs)
Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and Cohen. From the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise similarity matrix of the data.

This class is not yet an Estimator/Transformer; use the assignClusters() method to run the PowerIterationClustering algorithm.

New in version 2.4.0.

Notes

See Wikipedia on Spectral clustering

Examples

>>> data = [(1, 0, 0.5),
...         (2, 0, 0.5), (2, 1, 0.7),
...         (3, 0, 0.5), (3, 1, 0.7), (3, 2, 0.9),
...         (4, 0, 0.5), (4, 1, 0.7), (4, 2, 0.9), (4, 3, 1.1),
...         (5, 0, 0.5), (5, 1, 0.7), (5, 2, 0.9), (5, 3, 1.1), (5, 4, 1.3)]
>>> df = spark.createDataFrame(data).toDF("src", "dst", "weight").repartition(1)
>>> pic = PowerIterationClustering(k=2, weightCol="weight")
>>> pic.setMaxIter(40)
PowerIterationClustering...
>>> assignments = pic.assignClusters(df)
>>> assignments.sort(assignments.id).show(truncate=False)
+---+-------+
|id |cluster|
+---+-------+
|0  |0      |
|1  |0      |
|2  |0      |
|3  |0      |
|4  |0      |
|5  |1      |
+---+-------+
...
>>> pic_path = temp_path + "/pic"
>>> pic.save(pic_path)
>>> pic2 = PowerIterationClustering.load(pic_path)
>>> pic2.getK()
2
>>> pic2.getMaxIter()
40
>>> pic2.assignClusters(df).take(6) == assignments.take(6)
True

Methods

assignClusters(dataset) Run the PIC algorithm and returns a cluster assignment for each input vertex.
clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getDstCol() Gets the value of dstCol or its default value.
getInitMode() Gets the value of initMode or its default value.
getK() Gets the value of k or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getSrcCol() Gets the value of srcCol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of write().save(path).
set(param, value) Sets a parameter in the embedded param map.
setDstCol(value) Sets the value of dstCol.
setInitMode(value) Sets the value of initMode.
setK(value) Sets the value of k.
setMaxIter(value) Sets the value of maxIter.
setParams(self, *[, k, maxIter, initMode, ...]) Sets params for PowerIterationClustering.
setSrcCol(value) Sets the value of srcCol.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.

Attributes

dstCol
initMode
k
maxIter
params Returns all params ordered by name.
srcCol
weightCol

Methods Documentation

assignClusters(dataset)
Run the PIC algorithm and return a cluster assignment for each input vertex.

Parameters

dataset [pyspark.sql.DataFrame] A dataset with columns src, dst, weight representing the affinity matrix, which is the matrix A in the PIC paper. Suppose the src column value is i, the dst column value is j, and the weight column value is the similarity s_ij, which must be nonnegative. This is a symmetric matrix and hence s_ij = s_ji. For any (i, j) with nonzero similarity, there should be either (i, j, s_ij) or (j, i, s_ji) in the input. Rows with i = j are ignored, because we assume s_ij = 0.0.

Returns

pyspark.sql.DataFrame A dataset that contains columns of vertex id and the corresponding cluster for the id. The schema of it will be: id: Long, cluster: Int.

New in version 2.4.0.

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.


explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getDstCol()
Gets the value of dstCol or its default value.

New in version 2.4.0.

getInitMode()
Gets the value of initMode or its default value.

New in version 2.4.0.

getK()
Gets the value of k or its default value.

New in version 2.4.0.

getMaxIter()
Gets the value of maxIter or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getSrcCol()
Gets the value of srcCol or its default value.

New in version 2.4.0.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).


classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of write().save(path).

set(param, value)
Sets a parameter in the embedded param map.

setDstCol(value)
Sets the value of dstCol.

New in version 2.4.0.

setInitMode(value)
Sets the value of initMode.

New in version 2.4.0.

setK(value)
Sets the value of k.

New in version 2.4.0.

setMaxIter(value)
Sets the value of maxIter.

New in version 2.4.0.

setParams(self, *, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", weightCol=None)
Sets params for PowerIterationClustering.

New in version 2.4.0.

setSrcCol(value)
Sets the value of srcCol.

New in version 2.4.0.

setWeightCol(value)
Sets the value of weightCol.

New in version 2.4.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

dstCol = Param(parent='undefined', name='dstCol', doc='Name of the input column for destination vertex IDs.')

initMode = Param(parent='undefined', name='initMode', doc="The initialization algorithm. This can be either 'random' to use a random vector as vertex properties, or 'degree' to use a normalized sum of similarities with other vertices. Supported options: 'random' and 'degree'.")

k = Param(parent='undefined', name='k', doc='The number of clusters to create. Must be > 1.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

srcCol = Param(parent='undefined', name='srcCol', doc='Name of the input column for source vertex IDs.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


3.3.6 ML Functions

array_to_vector(col) Converts a column of array of numeric type into a column of dense vectors in MLlib.
vector_to_array(col[, dtype]) Converts a column of MLlib sparse/dense vectors into a column of dense arrays.

pyspark.ml.functions.array_to_vector

pyspark.ml.functions.array_to_vector(col)
Converts a column of array of numeric type into a column of dense vectors in MLlib.

New in version 3.1.0.

Parameters

col [pyspark.sql.Column or str] Input column

Returns

pyspark.sql.Column The converted column of MLlib dense vectors.

Examples

>>> from pyspark.ml.functions import array_to_vector
>>> df1 = spark.createDataFrame([([1.5, 2.5],),], schema='v1 array<double>')
>>> df1.select(array_to_vector('v1').alias('vec1')).collect()
[Row(vec1=DenseVector([1.5, 2.5]))]
>>> df2 = spark.createDataFrame([([1.5, 3.5],),], schema='v1 array<float>')
>>> df2.select(array_to_vector('v1').alias('vec1')).collect()
[Row(vec1=DenseVector([1.5, 3.5]))]
>>> df3 = spark.createDataFrame([([1, 3],),], schema='v1 array<int>')
>>> df3.select(array_to_vector('v1').alias('vec1')).collect()
[Row(vec1=DenseVector([1.0, 3.0]))]

pyspark.ml.functions.vector_to_array

pyspark.ml.functions.vector_to_array(col, dtype='float64')
Converts a column of MLlib sparse/dense vectors into a column of dense arrays.

New in version 3.0.0.

Parameters

col [pyspark.sql.Column or str] Input column

dtype [str, optional] The data type of the output array. Valid values: “float64” or “float32”.

Returns

pyspark.sql.Column The converted column of dense arrays.


Examples

>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.functions import vector_to_array
>>> from pyspark.mllib.linalg import Vectors as OldVectors
>>> df = spark.createDataFrame([
...     (Vectors.dense(1.0, 2.0, 3.0), OldVectors.dense(10.0, 20.0, 30.0)),
...     (Vectors.sparse(3, [(0, 2.0), (2, 3.0)]),
...      OldVectors.sparse(3, [(0, 20.0), (2, 30.0)]))],
...     ["vec", "oldVec"])
>>> df1 = df.select(vector_to_array("vec").alias("vec"),
...                 vector_to_array("oldVec").alias("oldVec"))
>>> df1.collect()
[Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]),
 Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])]

>>> df2 = df.select(vector_to_array("vec", "float32").alias("vec"),
...                 vector_to_array("oldVec", "float32").alias("oldVec"))
>>> df2.collect()
[Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]),
 Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])]

>>> df1.schema.fields
[StructField(vec,ArrayType(DoubleType,false),false),
 StructField(oldVec,ArrayType(DoubleType,false),false)]
>>> df2.schema.fields
[StructField(vec,ArrayType(FloatType,false),false),
 StructField(oldVec,ArrayType(FloatType,false),false)]

3.3.7 Vector and Matrix

Vector()
DenseVector(ar) A dense vector represented by a value array.
SparseVector(size, *args) A simple sparse vector class for passing data to MLlib.
Vectors() Factory methods for working with vectors.
Matrix(numRows, numCols[, isTransposed])
DenseMatrix(numRows, numCols, values[, ...]) Column-major dense matrix.
SparseMatrix(numRows, numCols, colPtrs, ...) Sparse Matrix stored in CSC format.
Matrices()

Vector

class pyspark.ml.linalg.Vector


Methods

toArray() Convert the vector into a numpy.ndarray.

Methods Documentation

toArray()
Convert the vector into a numpy.ndarray.

Returns

numpy.ndarray

DenseVector

class pyspark.ml.linalg.DenseVector(ar)
A dense vector represented by a value array. We use a numpy array for storage, and arithmetic will be delegated to the underlying numpy array.

Examples

>>> v = Vectors.dense([1.0, 2.0])
>>> u = Vectors.dense([3.0, 4.0])
>>> v + u
DenseVector([4.0, 6.0])
>>> 2 - v
DenseVector([1.0, 0.0])
>>> v / 2
DenseVector([0.5, 1.0])
>>> v * u
DenseVector([3.0, 8.0])
>>> u / v
DenseVector([3.0, 2.0])
>>> u % 2
DenseVector([1.0, 0.0])
>>> -v
DenseVector([-1.0, -2.0])

Methods

dot(other) Compute the dot product of two Vectors.
norm(p) Calculates the norm of a DenseVector.
numNonzeros() Number of nonzero elements.
squared_distance(other) Squared distance of two Vectors.
toArray() Returns the underlying numpy.ndarray


Attributes

values Returns the underlying numpy.ndarray

Methods Documentation

dot(other)
Compute the dot product of two Vectors. We support (Numpy array, list, SparseVector, or SciPy sparse) and a target NumPy array that is either 1- or 2-dimensional. Equivalent to calling numpy.dot of the two vectors.

Examples

>>> dense = DenseVector(array.array('d', [1., 2.]))
>>> dense.dot(dense)
5.0
>>> dense.dot(SparseVector(2, [0, 1], [2., 1.]))
4.0
>>> dense.dot(range(1, 3))
5.0
>>> dense.dot(np.array(range(1, 3)))
5.0
>>> dense.dot([1.,])
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
>>> dense.dot(np.reshape([1., 2., 3., 4.], (2, 2), order='F'))
array([ 5., 11.])
>>> dense.dot(np.reshape([1., 2., 3.], (3, 1), order='F'))
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch

norm(p)
Calculates the norm of a DenseVector.

Examples

>>> a = DenseVector([0, -1, 2, -3])
>>> a.norm(2)
3.7...
>>> a.norm(1)
6.0

numNonzeros()
Number of nonzero elements. This scans all active values and counts non-zeros.

squared_distance(other)
Squared distance of two Vectors.


Examples

>>> dense1 = DenseVector(array.array('d', [1., 2.]))
>>> dense1.squared_distance(dense1)
0.0
>>> dense2 = np.array([2., 1.])
>>> dense1.squared_distance(dense2)
2.0
>>> dense3 = [2., 1.]
>>> dense1.squared_distance(dense3)
2.0
>>> sparse1 = SparseVector(2, [0, 1], [2., 1.])
>>> dense1.squared_distance(sparse1)
2.0
>>> dense1.squared_distance([1.,])
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
>>> dense1.squared_distance(SparseVector(1, [0,], [1.,]))
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch

toArray()
Returns the underlying numpy.ndarray.

Attributes Documentation

values
Returns the underlying numpy.ndarray.

SparseVector

class pyspark.ml.linalg.SparseVector(size, *args)
A simple sparse vector class for passing data to MLlib. Users may alternatively pass SciPy's {scipy.sparse} data types.

Methods

dot(other) Dot product with a SparseVector or 1- or 2-dimensional Numpy array.
norm(p) Calculates the norm of a SparseVector.
numNonzeros() Number of nonzero elements.
squared_distance(other) Squared distance from a SparseVector or 1-dimensional NumPy array.
toArray() Returns a copy of this SparseVector as a 1-dimensional numpy.ndarray.


Methods Documentation

dot(other)
Dot product with a SparseVector or 1- or 2-dimensional Numpy array.

Examples

>>> a = SparseVector(4, [1, 3], [3.0, 4.0])
>>> a.dot(a)
25.0
>>> a.dot(array.array('d', [1., 2., 3., 4.]))
22.0
>>> b = SparseVector(4, [2], [1.0])
>>> a.dot(b)
0.0
>>> a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]))
array([ 22., 22.])
>>> a.dot([1., 2., 3.])
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
>>> a.dot(np.array([1., 2.]))
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
>>> a.dot(DenseVector([1., 2.]))
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
>>> a.dot(np.zeros((3, 2)))
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch

norm(p)
Calculates the norm of a SparseVector.

Examples

>>> a = SparseVector(4, [0, 1], [3., -4.])
>>> a.norm(1)
7.0
>>> a.norm(2)
5.0

numNonzeros()
Number of nonzero elements. This scans all active values and counts non-zeros.

squared_distance(other)
Squared distance from a SparseVector or 1-dimensional NumPy array.


Examples

>>> a = SparseVector(4, [1, 3], [3.0, 4.0])
>>> a.squared_distance(a)
0.0
>>> a.squared_distance(array.array('d', [1., 2., 3., 4.]))
11.0
>>> a.squared_distance(np.array([1., 2., 3., 4.]))
11.0
>>> b = SparseVector(4, [2], [1.0])
>>> a.squared_distance(b)
26.0
>>> b.squared_distance(a)
26.0
>>> b.squared_distance([1., 2.])
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch
>>> b.squared_distance(SparseVector(3, [1,], [1.0,]))
Traceback (most recent call last):
    ...
AssertionError: dimension mismatch

toArray()
Returns a copy of this SparseVector as a 1-dimensional numpy.ndarray.

Vectors

class pyspark.ml.linalg.Vectors
Factory methods for working with vectors.

Notes

Dense vectors are simply represented as NumPy array objects, so there is no need to convert them for use in MLlib. For sparse vectors, the factory methods in this class create an MLlib-compatible type, or users can pass in SciPy's scipy.sparse column vectors.

Methods

dense(*elements) Create a dense vector of 64-bit floats from a Python list or numbers.
norm(vector, p) Find norm of the given vector.
sparse(size, *args) Create a sparse vector, using either a dictionary, a list of (index, value) pairs, or two separate arrays of indices and values (sorted by index).
squared_distance(v1, v2) Squared distance between two vectors.
zeros(size)


Methods Documentation

static dense(*elements)
Create a dense vector of 64-bit floats from a Python list or numbers.

Examples

>>> Vectors.dense([1, 2, 3])
DenseVector([1.0, 2.0, 3.0])
>>> Vectors.dense(1.0, 2.0)
DenseVector([1.0, 2.0])

static norm(vector, p)
Find norm of the given vector.
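
A short illustrative example; p is interpreted as in numpy.linalg.norm (e.g. 1 for the L1 norm, 2 for the Euclidean norm):

from pyspark.ml.linalg import Vectors

v = Vectors.dense([3.0, -4.0])
print(Vectors.norm(v, 1))   # 7.0
print(Vectors.norm(v, 2))   # 5.0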

static sparse(size, *args)Create a sparse vector, using either a dictionary, a list of (index, value) pairs, or two separate arrays ofindices and values (sorted by index).

Parameters

size [int] Size of the vector.

args Non-zero entries, as a dictionary, list of tuples, or two sorted lists containing indices and values.

Examples

>>> Vectors.sparse(4, {1: 1.0, 3: 5.5})
SparseVector(4, {1: 1.0, 3: 5.5})
>>> Vectors.sparse(4, [(1, 1.0), (3, 5.5)])
SparseVector(4, {1: 1.0, 3: 5.5})
>>> Vectors.sparse(4, [1, 3], [1.0, 5.5])
SparseVector(4, {1: 1.0, 3: 5.5})

static squared_distance(v1, v2)
Squared distance between two vectors. v1 and v2 can be of type SparseVector, DenseVector, np.ndarray or array.array.

Examples

>>> a = Vectors.sparse(4, [(0, 1), (3, 4)])
>>> b = Vectors.dense([2, 5, 4, 1])
>>> a.squared_distance(b)
51.0

static zeros(size)
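The two helpers above are listed without examples; the following short sketch illustrates their observed behavior in this release (zeros(size) yields an all-zero DenseVector, and norm(vector, p) computes the p-norm):

from pyspark.ml.linalg import Vectors

z = Vectors.zeros(3)
print(z)                   # DenseVector([0.0, 0.0, 0.0])

v = Vectors.dense([3.0, -4.0])
print(Vectors.norm(v, 2))  # 5.0, the Euclidean norm
print(Vectors.norm(v, 1))  # 7.0, the sum of absolute values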


Matrix

class pyspark.ml.linalg.Matrix(numRows, numCols, isTransposed=False)

Methods

toArray() Returns its elements in a numpy.ndarray.

Methods Documentation

toArray()
Returns its elements in a numpy.ndarray.

DenseMatrix

class pyspark.ml.linalg.DenseMatrix(numRows, numCols, values, isTransposed=False)
Column-major dense matrix.

Methods

toArray(): Return a numpy.ndarray
toSparse(): Convert to SparseMatrix

Methods Documentation

toArray()
Return a numpy.ndarray

Examples

>>> m = DenseMatrix(2, 2, range(4))
>>> m.toArray()
array([[ 0.,  2.],
       [ 1.,  3.]])

toSparse()
Convert to SparseMatrix


SparseMatrix

class pyspark.ml.linalg.SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed=False)
Sparse Matrix stored in CSC format.

Methods

toArray(): Return a numpy.ndarray
toDense()

Methods Documentation

toArray()
Return a numpy.ndarray

toDense()
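A small sketch of the CSC layout behind SparseMatrix may help: colPtrs marks where each column's entries start in rowIndices/values. The matrix below is illustrative only.

from pyspark.ml.linalg import SparseMatrix

# 2x2 matrix [[9.0, 0.0],
#             [0.0, 8.0]]
# Column 0 owns one entry (row 0, value 9.0) and column 1 owns one (row 1, value 8.0),
# so colPtrs = [0, 1, 2] marks the start offset of each column plus the total count.
m = SparseMatrix(2, 2, colPtrs=[0, 1, 2], rowIndices=[0, 1], values=[9.0, 8.0])
print(m.toArray())   # dense NumPy view of the same data
print(m.toDense())   # equivalent DenseMatrix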

Matrices

class pyspark.ml.linalg.Matrices

Methods

dense(numRows, numCols, values): Create a DenseMatrix
sparse(numRows, numCols, colPtrs, ...): Create a SparseMatrix

Methods Documentation

static dense(numRows, numCols, values)
Create a DenseMatrix

static sparse(numRows, numCols, colPtrs, rowIndices, values)
Create a SparseMatrix

3.3.8 Recommendation

ALS(*args, **kwargs): Alternating Least Squares (ALS) matrix factorization.
ALSModel([java_model]): Model fitted by ALS.


ALS

class pyspark.ml.recommendation.ALS(*args, **kwargs)
Alternating Least Squares (ALS) matrix factorization.

ALS attempts to estimate the ratings matrix R as the product of two lower-rank matrices, X and Y, i.e. X * Yt = R. Typically these approximations are called 'factor' matrices. The general approach is iterative. During each iteration, one of the factor matrices is held constant, while the other is solved for using least squares. The newly-solved factor matrix is then held constant while solving for the other factor matrix.

This is a blocked implementation of the ALS factorization algorithm that groups the two sets of factors (referred to as "users" and "products") into blocks and reduces communication by only sending one copy of each user vector to each product block on each iteration, and only for the product blocks that need that user's feature vector. This is achieved by pre-computing some information about the ratings matrix to determine the "out-links" of each user (which blocks of products it will contribute to) and "in-link" information for each product (which of the feature vectors it receives from each user block it will depend on). This allows us to send only an array of feature vectors between each user block and product block, and have the product block find the users' ratings and update the products based on these messages.

For implicit preference data, the algorithm used is based on "Collaborative Filtering for Implicit Feedback Datasets", adapted for the blocked approach used here.

Essentially instead of finding the low-rank approximations to the rating matrix R, this finds the approximations for a preference matrix P where the elements of P are 1 if r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of indicated user preferences rather than explicit ratings given to items.
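A hedged sketch of fitting ALS on such implicit-feedback data follows; the DataFrame and column names are illustrative, and spark is assumed to be an existing SparkSession.

from pyspark.ml.recommendation import ALS

# "rating" here is a raw interaction count, not an explicit score (illustrative data).
clicks = spark.createDataFrame(
    [(0, 0, 3.0), (0, 1, 1.0), (1, 1, 4.0), (2, 2, 2.0)],
    ["user", "item", "rating"])

als = ALS(rank=10, implicitPrefs=True, alpha=40.0,
          userCol="user", itemCol="item", ratingCol="rating", seed=0)
model = als.fit(clicks)   # factors a 0/1 preference matrix weighted by the counts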

New in version 1.4.0.

Notes

The input rating dataframe to the ALS implementation should be deterministic. Nondeterministic data can cause failures when fitting the ALS model. For example, an order-sensitive operation like sampling after a repartition makes the dataframe output nondeterministic, as in df.repartition(2).sample(False, 0.5, 1618). Checkpointing the sampled dataframe or adding a sort before sampling can help make the dataframe deterministic.
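A minimal sketch of the two workarounds mentioned above, assuming df is the input rating DataFrame:

# Option 1: make the sample order-insensitive by sorting within partitions first.
deterministic_sample = (df.repartition(2)
                          .sortWithinPartitions("user", "item")
                          .sample(False, 0.5, 1618))

# Option 2: checkpoint the sampled DataFrame so it is materialized once
# (requires spark.sparkContext.setCheckpointDir(...) to have been called).
checkpointed_sample = df.repartition(2).sample(False, 0.5, 1618).checkpoint()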

Examples

>>> df = spark.createDataFrame(
...     [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
...     ["user", "item", "rating"])
>>> als = ALS(rank=10, seed=0)
>>> als.setMaxIter(5)
ALS...
>>> als.getMaxIter()
5
>>> als.setRegParam(0.1)
ALS...
>>> als.getRegParam()
0.1
>>> als.clear(als.regParam)
>>> model = als.fit(df)
>>> model.getBlockSize()
4096
>>> model.getUserCol()
'user'
>>> model.setUserCol("user")
ALSModel...
>>> model.getItemCol()
'item'
>>> model.setPredictionCol("newPrediction")
ALS...
>>> model.rank
10
>>> model.userFactors.orderBy("id").collect()
[Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
>>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
>>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
>>> predictions[0]
Row(user=0, item=2, newPrediction=0.692910...)
>>> predictions[1]
Row(user=1, item=0, newPrediction=3.473569...)
>>> predictions[2]
Row(user=2, item=0, newPrediction=-0.899198...)
>>> user_recs = model.recommendForAllUsers(3)
>>> user_recs.where(user_recs.user == 0).select("recommendations.item", "recommendations.rating").collect()
[Row(item=[0, 1, 2], rating=[3.910..., 1.997..., 0.692...])]
>>> item_recs = model.recommendForAllItems(3)
>>> item_recs.where(item_recs.item == 2).select("recommendations.user", "recommendations.rating").collect()
[Row(user=[2, 1, 0], rating=[4.892..., 3.991..., 0.692...])]
>>> user_subset = df.where(df.user == 2)
>>> user_subset_recs = model.recommendForUserSubset(user_subset, 3)
>>> user_subset_recs.select("recommendations.item", "recommendations.rating").first()
Row(item=[2, 1, 0], rating=[4.892..., 1.076..., -0.899...])
>>> item_subset = df.where(df.item == 0)
>>> item_subset_recs = model.recommendForItemSubset(item_subset, 3)
>>> item_subset_recs.select("recommendations.user", "recommendations.rating").first()
Row(user=[0, 1, 2], rating=[3.910..., 3.473..., -0.899...])
>>> als_path = temp_path + "/als"
>>> als.save(als_path)
>>> als2 = ALS.load(als_path)
>>> als.getMaxIter()
5
>>> model_path = temp_path + "/als_model"
>>> model.save(model_path)
>>> model2 = ALSModel.load(model_path)
>>> model.rank == model2.rank
True
>>> sorted(model.userFactors.collect()) == sorted(model2.userFactors.collect())
True
>>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect())
True
>>> model.transform(test).take(1) == model2.transform(test).take(1)
True


Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]): Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps): Fits a model to the input dataset for each param map in paramMaps.
getAlpha(): Gets the value of alpha or its default value.
getBlockSize(): Gets the value of blockSize or its default value.
getCheckpointInterval(): Gets the value of checkpointInterval or its default value.
getColdStartStrategy(): Gets the value of coldStartStrategy or its default value.
getFinalStorageLevel(): Gets the value of finalStorageLevel or its default value.
getImplicitPrefs(): Gets the value of implicitPrefs or its default value.
getIntermediateStorageLevel(): Gets the value of intermediateStorageLevel or its default value.
getItemCol(): Gets the value of itemCol or its default value.
getMaxIter(): Gets the value of maxIter or its default value.
getNonnegative(): Gets the value of nonnegative or its default value.
getNumItemBlocks(): Gets the value of numItemBlocks or its default value.
getNumUserBlocks(): Gets the value of numUserBlocks or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getRank(): Gets the value of rank or its default value.
getRatingCol(): Gets the value of ratingCol or its default value.
getRegParam(): Gets the value of regParam or its default value.
getSeed(): Gets the value of seed or its default value.
getUserCol(): Gets the value of userCol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setAlpha(value): Sets the value of alpha.
setBlockSize(value): Sets the value of blockSize.
setCheckpointInterval(value): Sets the value of checkpointInterval.
setColdStartStrategy(value): Sets the value of coldStartStrategy.
setFinalStorageLevel(value): Sets the value of finalStorageLevel.
setImplicitPrefs(value): Sets the value of implicitPrefs.
setIntermediateStorageLevel(value): Sets the value of intermediateStorageLevel.
setItemCol(value): Sets the value of itemCol.
setMaxIter(value): Sets the value of maxIter.
setNonnegative(value): Sets the value of nonnegative.
setNumBlocks(value): Sets both numUserBlocks and numItemBlocks to the specific value.
setNumItemBlocks(value): Sets the value of numItemBlocks.
setNumUserBlocks(value): Sets the value of numUserBlocks.
setParams(self, \*[, rank, maxIter, ...]): Sets params for ALS.
setPredictionCol(value): Sets the value of predictionCol.
setRank(value): Sets the value of rank.
setRatingCol(value): Sets the value of ratingCol.
setRegParam(value): Sets the value of regParam.
setSeed(value): Sets the value of seed.
setUserCol(value): Sets the value of userCol.
write(): Returns an MLWriter instance for this ML instance.

Attributes

alpha
blockSize
checkpointInterval
coldStartStrategy
finalStorageLevel
implicitPrefs
intermediateStorageLevel
itemCol
maxIter
nonnegative
numItemBlocks
numUserBlocks
params: Returns all params ordered by name.
predictionCol
rank
ratingCol
regParam
seed
userCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)
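A brief sketch of the list-of-param-maps form (df is assumed to be the ratings DataFrame from the example above; the parameter values are illustrative):

from pyspark.ml.recommendation import ALS

als = ALS(userCol="user", itemCol="item", ratingCol="rating", seed=0)
param_maps = [{als.rank: 5}, {als.rank: 10}, {als.regParam: 0.05}]

# One fitted ALSModel per param map, in the same order as param_maps.
models = als.fit(df, params=param_maps)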

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread-safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
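A hedged sketch of consuming the returned iterable, reusing als and df from the previous sketch; note that the (index, model) pairs may arrive out of order.

param_maps = [{als.rank: 5}, {als.rank: 10}]
models = [None] * len(param_maps)

# index says which param map produced the model, so results can be put
# back into param-map order even if they complete out of order.
for index, model in als.fitMultiple(df, param_maps):
    models[index] = model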

getAlpha()Gets the value of alpha or its default value.

New in version 1.4.0.

getBlockSize()Gets the value of blockSize or its default value.

getCheckpointInterval()Gets the value of checkpointInterval or its default value.

getColdStartStrategy()Gets the value of coldStartStrategy or its default value.

New in version 2.2.0.

getFinalStorageLevel()Gets the value of finalStorageLevel or its default value.

New in version 2.0.0.

getImplicitPrefs()Gets the value of implicitPrefs or its default value.

New in version 1.4.0.

getIntermediateStorageLevel()Gets the value of intermediateStorageLevel or its default value.

New in version 2.0.0.

getItemCol()Gets the value of itemCol or its default value.

New in version 1.4.0.

getMaxIter()Gets the value of maxIter or its default value.

getNonnegative()Gets the value of nonnegative or its default value.

New in version 1.4.0.

getNumItemBlocks()Gets the value of numItemBlocks or its default value.

New in version 1.4.0.

getNumUserBlocks()Gets the value of numUserBlocks or its default value.

New in version 1.4.0.


getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getRank()Gets the value of rank or its default value.

New in version 1.4.0.

getRatingCol()Gets the value of ratingCol or its default value.

New in version 1.4.0.

getRegParam()Gets the value of regParam or its default value.

getSeed()Gets the value of seed or its default value.

getUserCol()Gets the value of userCol or its default value.

New in version 1.4.0.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setAlpha(value)Sets the value of alpha.

New in version 1.4.0.

setBlockSize(value)Sets the value of blockSize.

New in version 3.0.0.


setCheckpointInterval(value)Sets the value of checkpointInterval.

setColdStartStrategy(value)Sets the value of coldStartStrategy .

New in version 2.2.0.

setFinalStorageLevel(value)Sets the value of finalStorageLevel.

New in version 2.0.0.

setImplicitPrefs(value)Sets the value of implicitPrefs.

New in version 1.4.0.

setIntermediateStorageLevel(value)Sets the value of intermediateStorageLevel.

New in version 2.0.0.

setItemCol(value)Sets the value of itemCol.

New in version 1.4.0.

setMaxIter(value)Sets the value of maxIter.

setNonnegative(value)Sets the value of nonnegative.

New in version 1.4.0.

setNumBlocks(value)Sets both numUserBlocks and numItemBlocks to the specific value.

New in version 1.4.0.

setNumItemBlocks(value)Sets the value of numItemBlocks.

New in version 1.4.0.

setNumUserBlocks(value)Sets the value of numUserBlocks.

New in version 1.4.0.

setParams(self, \*, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096)

Sets params for ALS.

New in version 1.4.0.

setPredictionCol(value)Sets the value of predictionCol.

setRank(value)Sets the value of rank.


New in version 1.4.0.

setRatingCol(value)Sets the value of ratingCol.

New in version 1.4.0.

setRegParam(value)Sets the value of regParam.

setSeed(value)Sets the value of seed.

setUserCol(value)Sets the value of userCol.

New in version 1.4.0.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

alpha = Param(parent='undefined', name='alpha', doc='alpha for implicit preference')

blockSize = Param(parent='undefined', name='blockSize', doc='block size for stacking input data in matrices. Data is stacked within partitions. If block size is more than remaining data in a partition then it is adjusted to the size of this data.')

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

coldStartStrategy = Param(parent='undefined', name='coldStartStrategy', doc="strategy for dealing with unknown or new users/items at prediction time. This may be useful in cross-validation or production scenarios, for handling user/item ids the model has not seen in the training data. Supported values: 'nan', 'drop'.")

finalStorageLevel = Param(parent='undefined', name='finalStorageLevel', doc='StorageLevel for ALS model factors.')

implicitPrefs = Param(parent='undefined', name='implicitPrefs', doc='whether to use implicit preference')

intermediateStorageLevel = Param(parent='undefined', name='intermediateStorageLevel', doc="StorageLevel for intermediate datasets. Cannot be 'NONE'.")

itemCol = Param(parent='undefined', name='itemCol', doc='column name for item ids. Ids must be within the integer value range.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

nonnegative = Param(parent='undefined', name='nonnegative', doc='whether to use nonnegative constraint for least squares')

numItemBlocks = Param(parent='undefined', name='numItemBlocks', doc='number of item blocks')

numUserBlocks = Param(parent='undefined', name='numUserBlocks', doc='number of user blocks')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

rank = Param(parent='undefined', name='rank', doc='rank of the factorization')

ratingCol = Param(parent='undefined', name='ratingCol', doc='column name for ratings')

regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')

seed = Param(parent='undefined', name='seed', doc='random seed.')

userCol = Param(parent='undefined', name='userCol', doc='column name for user ids. Ids must be within the integer value range.')


ALSModel

class pyspark.ml.recommendation.ALSModel(java_model=None)
Model fitted by ALS.

New in version 1.4.0.

Methods

clear(param): Clears a param from the param map if it has been explicitly set.
copy([extra]): Creates a copy of this instance with the same uid and some extra params.
explainParam(param): Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams(): Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]): Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBlockSize(): Gets the value of blockSize or its default value.
getColdStartStrategy(): Gets the value of coldStartStrategy or its default value.
getItemCol(): Gets the value of itemCol or its default value.
getOrDefault(param): Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName): Gets a param by its name.
getPredictionCol(): Gets the value of predictionCol or its default value.
getUserCol(): Gets the value of userCol or its default value.
hasDefault(param): Checks whether a param has a default value.
hasParam(paramName): Tests whether this instance contains a param with a given (string) name.
isDefined(param): Checks whether a param is explicitly set by user or has a default value.
isSet(param): Checks whether a param is explicitly set by user.
load(path): Reads an ML instance from the input path, a shortcut of read().load(path).
read(): Returns an MLReader instance for this class.
recommendForAllItems(numUsers): Returns top numUsers users recommended for each item, for all items.
recommendForAllUsers(numItems): Returns top numItems items recommended for each user, for all users.
recommendForItemSubset(dataset, numUsers): Returns top numUsers users recommended for each item id in the input data set.
recommendForUserSubset(dataset, numItems): Returns top numItems items recommended for each user id in the input data set.
save(path): Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value): Sets a parameter in the embedded param map.
setBlockSize(value): Sets the value of blockSize.
setColdStartStrategy(value): Sets the value of coldStartStrategy.
setItemCol(value): Sets the value of itemCol.
setPredictionCol(value): Sets the value of predictionCol.
setUserCol(value): Sets the value of userCol.
transform(dataset[, params]): Transforms the input dataset with optional parameters.
write(): Returns an MLWriter instance for this ML instance.

Attributes

blockSize
coldStartStrategy
itemCol
itemFactors: a DataFrame that stores item factors in two columns: id and features
params: Returns all params ordered by name.
predictionCol
rank: rank of the matrix factorization model
userCol
userFactors: a DataFrame that stores user factors in two columns: id and features

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getBlockSize()Gets the value of blockSize or its default value.

getColdStartStrategy()Gets the value of coldStartStrategy or its default value.

New in version 2.2.0.

getItemCol()Gets the value of itemCol or its default value.

New in version 1.4.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getUserCol()Gets the value of userCol or its default value.

New in version 1.4.0.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

recommendForAllItems(numUsers)Returns top numUsers users recommended for each item, for all items.

New in version 2.2.0.

Parameters

numUsers [int] max number of recommendations for each item


Returns

pyspark.sql.DataFrame a DataFrame of (itemCol, recommendations), where recommendations are stored as an array of (userCol, rating) Rows.

recommendForAllUsers(numItems)
Returns top numItems items recommended for each user, for all users.

New in version 2.2.0.

Parameters

numItems [int] max number of recommendations for each user

Returns

pyspark.sql.DataFrame a DataFrame of (userCol, recommendations), where recommendations are stored as an array of (itemCol, rating) Rows.

recommendForItemSubset(dataset, numUsers)
Returns top numUsers users recommended for each item id in the input data set. Note that if there are duplicate ids in the input dataset, only one set of recommendations per unique id will be returned.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] a DataFrame containing a column of item ids. The column name must match itemCol.

numUsers [int] max number of recommendations for each item

Returns

pyspark.sql.DataFrame a DataFrame of (itemCol, recommendations), where recommendations are stored as an array of (userCol, rating) Rows.

recommendForUserSubset(dataset, numItems)
Returns top numItems items recommended for each user id in the input data set. Note that if there are duplicate ids in the input dataset, only one set of recommendations per unique id will be returned.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] a DataFrame containing a column of user ids. The column name must match userCol.

numItems [int] max number of recommendations for each user

Returns

pyspark.sql.DataFrame a DataFrame of (userCol, recommendations), where recommendations are stored as an array of (itemCol, rating) Rows.
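A short sketch (assuming the fitted model and ratings DataFrame df from the class example above) that flattens the nested recommendations column into one row per (user, item, rating):

from pyspark.sql.functions import explode

some_users = df.select("user").distinct().limit(10)   # column name must match userCol
recs = model.recommendForUserSubset(some_users, 3)

# recommendations is an array of (item, rating) structs; explode it into rows.
flat = (recs
        .select("user", explode("recommendations").alias("rec"))
        .select("user", "rec.item", "rec.rating"))
flat.show()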

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setBlockSize(value)Sets the value of blockSize.

New in version 3.0.0.


setColdStartStrategy(value)Sets the value of coldStartStrategy .

New in version 3.0.0.

setItemCol(value)Sets the value of itemCol.

New in version 3.0.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setUserCol(value)Sets the value of userCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset
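A minimal sketch of scoring rows that contain ids unseen during training, assuming model and spark from the ALS example above; with the default coldStartStrategy of 'nan' such rows get NaN predictions, while 'drop' removes them.

model.setColdStartStrategy("drop")
unseen = spark.createDataFrame([(0, 2), (99, 0)], ["user", "item"])  # user 99 was never trained on
predictions = model.transform(unseen)
predictions.show()   # the row for user 99 is dropped instead of being scored as NaN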

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

blockSize = Param(parent='undefined', name='blockSize', doc='block size for stacking input data in matrices. Data is stacked within partitions. If block size is more than remaining data in a partition then it is adjusted to the size of this data.')

coldStartStrategy = Param(parent='undefined', name='coldStartStrategy', doc="strategy for dealing with unknown or new users/items at prediction time. This may be useful in cross-validation or production scenarios, for handling user/item ids the model has not seen in the training data. Supported values: 'nan', 'drop'.")

itemCol = Param(parent='undefined', name='itemCol', doc='column name for item ids. Ids must be within the integer value range.')

itemFactors
a DataFrame that stores item factors in two columns: id and features

New in version 1.4.0.

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

rank
rank of the matrix factorization model

New in version 1.4.0.

userCol = Param(parent='undefined', name='userCol', doc='column name for user ids. Ids must be within the integer value range.')


userFactors
a DataFrame that stores user factors in two columns: id and features

New in version 1.4.0.

3.3.9 Regression

AFTSurvivalRegression(*args, **kwargs): Accelerated Failure Time (AFT) Model Survival Regression.
AFTSurvivalRegressionModel([java_model]): Model fitted by AFTSurvivalRegression.
DecisionTreeRegressor(*args, **kwargs): Decision tree learning algorithm for regression. It supports both continuous and categorical features.
DecisionTreeRegressionModel([java_model]): Model fitted by DecisionTreeRegressor.
GBTRegressor(*args, **kwargs): Gradient-Boosted Trees (GBTs) learning algorithm for regression. It supports both continuous and categorical features.
GBTRegressionModel([java_model]): Model fitted by GBTRegressor.
GeneralizedLinearRegression(*args, **kwargs): Generalized Linear Regression.
GeneralizedLinearRegressionModel([java_model]): Model fitted by GeneralizedLinearRegression.
GeneralizedLinearRegressionSummary([java_obj]): Generalized linear regression results evaluated on a dataset.
GeneralizedLinearRegressionTrainingSummary([...]): Generalized linear regression training results.
IsotonicRegression(*args, **kwargs): Currently implemented using parallelized pool adjacent violators algorithm.
IsotonicRegressionModel([java_model]): Model fitted by IsotonicRegression.
LinearRegression(*args, **kwargs): Linear regression.
LinearRegressionModel([java_model]): Model fitted by LinearRegression.
LinearRegressionSummary([java_obj]): Linear regression results evaluated on a dataset.
LinearRegressionTrainingSummary([java_obj]): Linear regression training results.
RandomForestRegressor(*args, **kwargs): Random Forest learning algorithm for regression. It supports both continuous and categorical features.
RandomForestRegressionModel([java_model]): Model fitted by RandomForestRegressor.
FMRegressor(*args, **kwargs): Factorization Machines learning algorithm for regression.
FMRegressionModel([java_model]): Model fitted by FMRegressor.


AFTSurvivalRegression

class pyspark.ml.regression.AFTSurvivalRegression(*args, **kwargs)
Accelerated Failure Time (AFT) Model Survival Regression

Fit a parametric AFT survival regression model based on the Weibull distribution of the survival time.

Notes

For more information see Wikipedia page on AFT Model

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0), 1.0),
...     (1e-40, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
>>> aftsr = AFTSurvivalRegression()
>>> aftsr.setMaxIter(10)
AFTSurvivalRegression...
>>> aftsr.getMaxIter()
10
>>> aftsr.clear(aftsr.maxIter)
>>> model = aftsr.fit(df)
>>> model.getMaxBlockSizeInMB()
0.0
>>> model.setFeaturesCol("features")
AFTSurvivalRegressionModel...
>>> model.predict(Vectors.dense(6.3))
1.0
>>> model.predictQuantiles(Vectors.dense(6.3))
DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
>>> model.transform(df).show()
+-------+---------+------+----------+
|  label| features|censor|prediction|
+-------+---------+------+----------+
|    1.0|    [1.0]|   1.0|       1.0|
|1.0E-40|(1,[],[])|   0.0|       1.0|
+-------+---------+------+----------+
...
>>> aftsr_path = temp_path + "/aftsr"
>>> aftsr.save(aftsr_path)
>>> aftsr2 = AFTSurvivalRegression.load(aftsr_path)
>>> aftsr2.getMaxIter()
100
>>> model_path = temp_path + "/aftsr_model"
>>> model.save(model_path)
>>> model2 = AFTSurvivalRegressionModel.load(model_path)
>>> model.coefficients == model2.coefficients
True
>>> model.intercept == model2.intercept
True
>>> model.scale == model2.scale
True
>>> model.transform(df).take(1) == model2.transform(df).take(1)
True

New in version 1.6.0.

Methods

clear(param) Clears a param from the param map if it has beenexplicitly set.

copy([extra]) Creates a copy of this instance with the same uid andsome extra params.

explainParam(param) Explains a single param and returns its name, doc,and optional default value and user-supplied value ina string.

explainParams() Returns the documentation of all params with theiroptionally default values and user-supplied values.

extractParamMap([extra]) Extracts the embedded default param values anduser-supplied values, and then merges them with ex-tra values from input into a flat param map, wherethe latter value is used if there exist conflicts, i.e.,with ordering: default param values < user-suppliedvalues < extra.

fit(dataset[, params]) Fits a model to the input dataset with optional param-eters.

fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param mapin paramMaps.

getAggregationDepth() Gets the value of aggregationDepth or its defaultvalue.

getCensorCol() Gets the value of censorCol or its default value.getFeaturesCol() Gets the value of featuresCol or its default value.getFitIntercept() Gets the value of fitIntercept or its default value.getLabelCol() Gets the value of labelCol or its default value.getMaxBlockSizeInMB() Gets the value of maxBlockSizeInMB or its default

value.getMaxIter() Gets the value of maxIter or its default value.getOrDefault(param) Gets the value of a param in the user-supplied param

map or its default value.getParam(paramName) Gets a param by its name.getPredictionCol() Gets the value of predictionCol or its default value.getQuantileProbabilities() Gets the value of quantileProbabilities or its default

value.getQuantilesCol() Gets the value of quantilesCol or its default value.getTol() Gets the value of tol or its default value.hasDefault(param) Checks whether a param has a default value.hasParam(paramName) Tests whether this instance contains a param with a

given (string) name.isDefined(param) Checks whether a param is explicitly set by user or

has a default value.isSet(param) Checks whether a param is explicitly set by user.


load(path) Reads an ML instance from the input path, a shortcut

of read().load(path).read() Returns an MLReader instance for this class.save(path) Save this ML instance to the given path, a shortcut of

‘write().save(path)’.set(param, value) Sets a parameter in the embedded param map.setAggregationDepth(value) Sets the value of aggregationDepth.setCensorCol(value) Sets the value of censorCol.setFeaturesCol(value) Sets the value of featuresCol.setFitIntercept(value) Sets the value of fitIntercept.setLabelCol(value) Sets the value of labelCol.setMaxBlockSizeInMB(value) Sets the value of maxBlockSizeInMB.setMaxIter(value) Sets the value of maxIter.setParams(*args, **kwargs) setParams(self, *, featuresCol=”features”, la-

belCol=”label”, predictionCol=”prediction”,fitIntercept=True, maxIter=100, tol=1E-6, cen-sorCol=”censor”, quantileProbabilities=[0.01,0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], quan-tilesCol=None, aggregationDepth=2, maxBlock-SizeInMB=0.0):

setPredictionCol(value) Sets the value of predictionCol.setQuantileProbabilities(value) Sets the value of quantileProbabilities.setQuantilesCol(value) Sets the value of quantilesCol.setTol(value) Sets the value of tol.write() Returns an MLWriter instance for this ML instance.

Attributes

aggregationDepth
censorCol
featuresCol
fitIntercept
labelCol
maxBlockSizeInMB
maxIter
params: Returns all params ordered by name.
predictionCol
quantileProbabilities
quantilesCol
tol


Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first callsParams.copy and then make a copy of the companion Java pipeline component with extra params. So boththe Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.


Returns

_FitMultipleIterator A thread safe iterable which contains one model for eachparam map. Each call to next(modelIterator) will return (index, model) where model wasfit using paramMaps[index]. index values may not be sequential.

getAggregationDepth()Gets the value of aggregationDepth or its default value.

getCensorCol()Gets the value of censorCol or its default value.

New in version 1.6.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getFitIntercept()Gets the value of fitIntercept or its default value.

getLabelCol()Gets the value of labelCol or its default value.

getMaxBlockSizeInMB()Gets the value of maxBlockSizeInMB or its default value.

getMaxIter()Gets the value of maxIter or its default value.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getQuantileProbabilities()Gets the value of quantileProbabilities or its default value.

New in version 1.6.0.

getQuantilesCol()Gets the value of quantilesCol or its default value.

New in version 1.6.0.

getTol()Gets the value of tol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.


classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setAggregationDepth(value)Sets the value of aggregationDepth.

New in version 2.1.0.

setCensorCol(value)Sets the value of censorCol.

New in version 1.6.0.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setFitIntercept(value)Sets the value of fitIntercept.

New in version 1.6.0.

setLabelCol(value)Sets the value of labelCol.

New in version 3.0.0.

setMaxBlockSizeInMB(value)Sets the value of maxBlockSizeInMB.

New in version 3.1.0.

setMaxIter(value)Sets the value of maxIter.

New in version 1.6.0.

setParams(*args, **kwargs)
setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], quantilesCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0):

New in version 1.6.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setQuantileProbabilities(value)Sets the value of quantileProbabilities.

New in version 1.6.0.


setQuantilesCol(value)Sets the value of quantilesCol.

New in version 1.6.0.

setTol(value)Sets the value of tol.

New in version 1.6.0.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')

censorCol = Param(parent='undefined', name='censorCol', doc='censor column name. The value of this column could be 0 or 1. If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

quantileProbabilities = Param(parent='undefined', name='quantileProbabilities', doc='quantile probabilities array. Values of the quantile probabilities array should be in the range (0, 1) and the array should be non-empty.')

quantilesCol = Param(parent='undefined', name='quantilesCol', doc='quantiles column name. This column will output quantiles of corresponding quantileProbabilities if it is set.')

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

AFTSurvivalRegressionModel

class pyspark.ml.regression.AFTSurvivalRegressionModel(java_model=None)
Model fitted by AFTSurvivalRegression.

New in version 1.6.0.

Methods

clear(param) Clears a param from the param map if it has beenexplicitly set.

copy([extra]) Creates a copy of this instance with the same uid andsome extra params.

explainParam(param) Explains a single param and returns its name, doc,and optional default value and user-supplied value ina string.


explainParams() Returns the documentation of all params with their

optionally default values and user-supplied values.extractParamMap([extra]) Extracts the embedded default param values and

user-supplied values, and then merges them with ex-tra values from input into a flat param map, wherethe latter value is used if there exist conflicts, i.e.,with ordering: default param values < user-suppliedvalues < extra.

getAggregationDepth() Gets the value of aggregationDepth or its defaultvalue.

getCensorCol() Gets the value of censorCol or its default value.getFeaturesCol() Gets the value of featuresCol or its default value.getFitIntercept() Gets the value of fitIntercept or its default value.getLabelCol() Gets the value of labelCol or its default value.getMaxBlockSizeInMB() Gets the value of maxBlockSizeInMB or its default

value.getMaxIter() Gets the value of maxIter or its default value.getOrDefault(param) Gets the value of a param in the user-supplied param

map or its default value.getParam(paramName) Gets a param by its name.getPredictionCol() Gets the value of predictionCol or its default value.getQuantileProbabilities() Gets the value of quantileProbabilities or its default

value.getQuantilesCol() Gets the value of quantilesCol or its default value.getTol() Gets the value of tol or its default value.hasDefault(param) Checks whether a param has a default value.hasParam(paramName) Tests whether this instance contains a param with a

given (string) name.isDefined(param) Checks whether a param is explicitly set by user or

has a default value.isSet(param) Checks whether a param is explicitly set by user.load(path) Reads an ML instance from the input path, a shortcut

of read().load(path).predict(value) Predict label for the given features.predictQuantiles(features) Predicted Quantilesread() Returns an MLReader instance for this class.save(path) Save this ML instance to the given path, a shortcut of

‘write().save(path)’.set(param, value) Sets a parameter in the embedded param map.setFeaturesCol(value) Sets the value of featuresCol.setPredictionCol(value) Sets the value of predictionCol.setQuantileProbabilities(value) Sets the value of quantileProbabilities.setQuantilesCol(value) Sets the value of quantilesCol.transform(dataset[, params]) Transforms the input dataset with optional parame-

ters.write() Returns an MLWriter instance for this ML instance.


Attributes

aggregationDepth
censorCol
coefficients: Model coefficients.
featuresCol
fitIntercept
intercept: Model intercept.
labelCol
maxBlockSizeInMB
maxIter
numFeatures: Returns the number of features the model was trained on.
params: Returns all params ordered by name.
predictionCol
quantileProbabilities
quantilesCol
scale: Model scale parameter.
tol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first callsParams.copy and then make a copy of the companion Java pipeline component with extra params. So boththe Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)Extracts the embedded default param values and user-supplied values, and then merges them with extravalues from input into a flat param map, where the latter value is used if there exist conflicts, i.e., withordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map


getAggregationDepth()Gets the value of aggregationDepth or its default value.

getCensorCol()Gets the value of censorCol or its default value.

New in version 1.6.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getFitIntercept()Gets the value of fitIntercept or its default value.

getLabelCol()Gets the value of labelCol or its default value.

getMaxBlockSizeInMB()Gets the value of maxBlockSizeInMB or its default value.

getMaxIter()Gets the value of maxIter or its default value.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getQuantileProbabilities()Gets the value of quantileProbabilities or its default value.

New in version 1.6.0.

getQuantilesCol()Gets the value of quantilesCol or its default value.

New in version 1.6.0.

getTol()Gets the value of tol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
Predict label for the given features.

New in version 3.0.0.


predictQuantiles(features)
Predicted Quantiles

New in version 2.0.0.
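A short sketch of the two single-vector helpers, reusing model from the AFTSurvivalRegression example above:

from pyspark.ml.linalg import Vectors

features = Vectors.dense(6.3)
print(model.predict(features))           # a single float: the predicted survival time
print(model.predictQuantiles(features))  # a DenseVector, one value per quantileProbabilities entry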

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setQuantileProbabilities(value)Sets the value of quantileProbabilities.

New in version 3.0.0.

setQuantilesCol(value)Sets the value of quantilesCol.

New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')

censorCol = Param(parent='undefined', name='censorCol', doc='censor column name. The value of this column could be 0 or 1. If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored.')

coefficientsModel coefficients.

New in version 2.0.0.

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')


interceptModel intercept.

New in version 1.6.0.

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

numFeaturesReturns the number of features the model was trained on. If unknown, returns -1

New in version 2.1.0.

paramsReturns all params ordered by name. The default implementation uses dir() to get all attributes of typeParam.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

quantileProbabilities = Param(parent='undefined', name='quantileProbabilities', doc='quantile probabilities array. Values of the quantile probabilities array should be in the range (0, 1) and the array should be non-empty.')

quantilesCol = Param(parent='undefined', name='quantilesCol', doc='quantiles column name. This column will output quantiles of corresponding quantileProbabilities if it is set.')

scaleModel scale parameter.

New in version 1.6.0.

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

DecisionTreeRegressor

class pyspark.ml.regression.DecisionTreeRegressor(*args, **kwargs)
Decision tree learning algorithm for regression. It supports both continuous and categorical features.

New in version 1.4.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> dt = DecisionTreeRegressor(maxDepth=2)
>>> dt.setVarianceCol("variance")
DecisionTreeRegressor...
>>> model = dt.fit(df)
>>> model.getVarianceCol()
'variance'
>>> model.setLeafCol("leafId")
DecisionTreeRegressionModel...
>>> model.depth
1
>>> model.numNodes
3
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> model.predictLeaf(test0.head().features)
0.0
>>> result.leafId
0.0
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> dtr_path = temp_path + "/dtr"
>>> dt.save(dtr_path)
>>> dt2 = DecisionTreeRegressor.load(dtr_path)
>>> dt2.getMaxDepth()
2
>>> model_path = temp_path + "/dtr_model"
>>> model.save(model_path)
>>> model2 = DecisionTreeRegressionModel.load(model_path)
>>> model.numNodes == model2.numNodes
True
>>> model.depth == model2.depth
True
>>> model.transform(test1).head().variance
0.0
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> df3 = spark.createDataFrame([
...     (1.0, 0.2, Vectors.dense(1.0)),
...     (1.0, 0.8, Vectors.dense(1.0)),
...     (0.0, 1.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
>>> dt3 = DecisionTreeRegressor(maxDepth=2, weightCol="weight", varianceCol="variance")
>>> model3 = dt3.fit(df3)
>>> print(model3.toDebugString)
DecisionTreeRegressionModel...depth=1, numNodes=3...

Methods

clear(param) Clears a param from the param map if it has beenexplicitly set.

copy([extra]) Creates a copy of this instance with the same uid andsome extra params.

explainParam(param) Explains a single param and returns its name, doc,and optional default value and user-supplied value ina string.

explainParams() Returns the documentation of all params with theiroptionally default values and user-supplied values.


extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

fit(dataset[, params]) Fits a model to the input dataset with optional param-eters.

fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param mapin paramMaps.

getCacheNodeIds() Gets the value of cacheNodeIds or its default value.getCheckpointInterval() Gets the value of checkpointInterval or its default

value.getFeaturesCol() Gets the value of featuresCol or its default value.getImpurity() Gets the value of impurity or its default value.getLabelCol() Gets the value of labelCol or its default value.getLeafCol() Gets the value of leafCol or its default value.getMaxBins() Gets the value of maxBins or its default value.getMaxDepth() Gets the value of maxDepth or its default value.getMaxMemoryInMB() Gets the value of maxMemoryInMB or its default

value.getMinInfoGain() Gets the value of minInfoGain or its default value.getMinInstancesPerNode() Gets the value of minInstancesPerNode or its default

value.getMinWeightFractionPerNode() Gets the value of minWeightFractionPerNode or its

default value.getOrDefault(param) Gets the value of a param in the user-supplied param

map or its default value.getParam(paramName) Gets a param by its name.getPredictionCol() Gets the value of predictionCol or its default value.getSeed() Gets the value of seed or its default value.getVarianceCol() Gets the value of varianceCol or its default value.getWeightCol() Gets the value of weightCol or its default value.hasDefault(param) Checks whether a param has a default value.hasParam(paramName) Tests whether this instance contains a param with a

given (string) name.isDefined(param) Checks whether a param is explicitly set by user or

has a default value.isSet(param) Checks whether a param is explicitly set by user.load(path) Reads an ML instance from the input path, a shortcut

of read().load(path).read() Returns an MLReader instance for this class.save(path) Save this ML instance to the given path, a shortcut of

‘write().save(path)’.set(param, value) Sets a parameter in the embedded param map.setCacheNodeIds(value) Sets the value of cacheNodeIds.setCheckpointInterval(value) Sets the value of checkpointInterval.setFeaturesCol(value) Sets the value of featuresCol.setImpurity(value) Sets the value of impurity .setLabelCol(value) Sets the value of labelCol.setLeafCol(value) Sets the value of leafCol.


setMaxBins(value) Sets the value of maxBins.setMaxDepth(value) Sets the value of maxDepth.setMaxMemoryInMB(value) Sets the value of maxMemoryInMB.setMinInfoGain(value) Sets the value of minInfoGain.setMinInstancesPerNode(value) Sets the value of minInstancesPerNode.setMinWeightFractionPerNode(value) Sets the value of

minWeightFractionPerNode.setParams(self, \*[, featuresCol, labelCol, . . . ]) Sets params for the DecisionTreeRegressor.setPredictionCol(value) Sets the value of predictionCol.setSeed(value) Sets the value of seed.setVarianceCol(value) Sets the value of varianceCol.setWeightCol(value) Sets the value of weightCol.write() Returns an MLWriter instance for this ML instance.

Attributes

cacheNodeIds
checkpointInterval
featuresCol
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
params Returns all params ordered by name.
predictionCol
seed
supportedImpurities
varianceCol
weightCol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance


explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
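As an illustration of the precedence just described (default param values < user-supplied values < extra), a minimal sketch; the values are arbitrary:

from pyspark.ml.regression import DecisionTreeRegressor

dt = DecisionTreeRegressor(maxDepth=3)      # user-supplied value (the default is 5)
merged = dt.extractParamMap({dt.maxDepth: 7})
print(merged[dt.maxDepth])                  # 7: the `extra` value wins over the user value
print(dt.extractParamMap()[dt.maxDepth])    # 3: the user value wins over the default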

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)
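A minimal sketch of the list-of-param-maps form (assuming an active SparkSession named spark); each map yields its own fitted model:

from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import DecisionTreeRegressor

df = spark.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))],
    ["label", "features"])
dt = DecisionTreeRegressor()
# Two param maps -> a list of two fitted models, one per maxDepth setting.
models = dt.fit(df, params=[{dt.maxDepth: 2}, {dt.maxDepth: 4}])
print(len(models), [m.depth for m in models])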

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getCacheNodeIds()Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()Gets the value of checkpointInterval or its default value.

getFeaturesCol()Gets the value of featuresCol or its default value.

getImpurity()Gets the value of impurity or its default value.


New in version 1.4.0.

getLabelCol()Gets the value of labelCol or its default value.

getLeafCol()Gets the value of leafCol or its default value.

getMaxBins()Gets the value of maxBins or its default value.

getMaxDepth()Gets the value of maxDepth or its default value.

getMaxMemoryInMB()Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getSeed()Gets the value of seed or its default value.

getVarianceCol()Gets the value of varianceCol or its default value.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.
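Taken together, a minimal sketch of how these checks relate for a param that has a default but was not set explicitly (per the setParams signature later in this class, maxDepth defaults to 5):

from pyspark.ml.regression import DecisionTreeRegressor

dt = DecisionTreeRegressor()
print(dt.isSet(dt.maxDepth))         # False: not explicitly set by the user
print(dt.hasDefault(dt.maxDepth))    # True: a default value exists
print(dt.isDefined(dt.maxDepth))     # True: explicitly set or has a default
print(dt.getOrDefault(dt.maxDepth))  # 5: falls back to the default
dt.setMaxDepth(3)
print(dt.isSet(dt.maxDepth))         # True once set explicitly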

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.


save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setCacheNodeIds(value)Sets the value of cacheNodeIds.

New in version 1.4.0.

setCheckpointInterval(value)Sets the value of checkpointInterval.

New in version 1.4.0.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setImpurity(value)Sets the value of impurity .

New in version 1.4.0.

setLabelCol(value)Sets the value of labelCol.

New in version 3.0.0.

setLeafCol(value)Sets the value of leafCol.

setMaxBins(value)Sets the value of maxBins.

New in version 1.4.0.

setMaxDepth(value)Sets the value of maxDepth.

New in version 1.4.0.

setMaxMemoryInMB(value)Sets the value of maxMemoryInMB.

New in version 1.4.0.

setMinInfoGain(value)Sets the value of minInfoGain.

New in version 1.4.0.

setMinInstancesPerNode(value)Sets the value of minInstancesPerNode.

New in version 1.4.0.

setMinWeightFractionPerNode(value)Sets the value of minWeightFractionPerNode.

New in version 3.0.0.


setParams(self, \*, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", seed=None, varianceCol=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)

Sets params for the DecisionTreeRegressor.

New in version 1.4.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setSeed(value)Sets the value of seed.

setVarianceCol(value)Sets the value of varianceCol.

New in version 2.0.0.

setWeightCol(value)Sets the value of weightCol.

New in version 3.0.0.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')

maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')

maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')

maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')

minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')

minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')

minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

supportedImpurities = ['variance']


varianceCol = Param(parent='undefined', name='varianceCol', doc='column name for the biased sample variance of prediction.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

DecisionTreeRegressionModel

class pyspark.ml.regression.DecisionTreeRegressionModel(java_model=None)
Model fitted by DecisionTreeRegressor.

New in version 1.4.0.

Methods

clear(param) Clears a param from the param map if it has beenexplicitly set.

copy([extra]) Creates a copy of this instance with the same uid andsome extra params.

explainParam(param) Explains a single param and returns its name, doc,and optional default value and user-supplied value ina string.

explainParams() Returns the documentation of all params with theiroptionally default values and user-supplied values.

extractParamMap([extra]) Extracts the embedded default param values anduser-supplied values, and then merges them with ex-tra values from input into a flat param map, wherethe latter value is used if there exist conflicts, i.e.,with ordering: default param values < user-suppliedvalues < extra.

getCacheNodeIds() Gets the value of cacheNodeIds or its default value.getCheckpointInterval() Gets the value of checkpointInterval or its default

value.getFeaturesCol() Gets the value of featuresCol or its default value.getImpurity() Gets the value of impurity or its default value.getLabelCol() Gets the value of labelCol or its default value.getLeafCol() Gets the value of leafCol or its default value.getMaxBins() Gets the value of maxBins or its default value.getMaxDepth() Gets the value of maxDepth or its default value.getMaxMemoryInMB() Gets the value of maxMemoryInMB or its default

value.getMinInfoGain() Gets the value of minInfoGain or its default value.getMinInstancesPerNode() Gets the value of minInstancesPerNode or its default

value.getMinWeightFractionPerNode() Gets the value of minWeightFractionPerNode or its

default value.getOrDefault(param) Gets the value of a param in the user-supplied param

map or its default value.getParam(paramName) Gets a param by its name.getPredictionCol() Gets the value of predictionCol or its default value.getSeed() Gets the value of seed or its default value.getVarianceCol() Gets the value of varianceCol or its default value.getWeightCol() Gets the value of weightCol or its default value.


hasDefault(param) Checks whether a param has a default value.hasParam(paramName) Tests whether this instance contains a param with a

given (string) name.isDefined(param) Checks whether a param is explicitly set by user or

has a default value.isSet(param) Checks whether a param is explicitly set by user.load(path) Reads an ML instance from the input path, a shortcut

of read().load(path).predict(value) Predict label for the given features.predictLeaf(value) Predict the indices of the leaves corresponding to the

feature vector.read() Returns an MLReader instance for this class.save(path) Save this ML instance to the given path, a shortcut of

‘write().save(path)’.set(param, value) Sets a parameter in the embedded param map.setFeaturesCol(value) Sets the value of featuresCol.setLeafCol(value) Sets the value of leafCol.setPredictionCol(value) Sets the value of predictionCol.setVarianceCol(value) Sets the value of varianceCol.transform(dataset[, params]) Transforms the input dataset with optional parame-

ters.write() Returns an MLWriter instance for this ML instance.

Attributes

cacheNodeIds
checkpointInterval
depth Return depth of the decision tree.
featureImportances Estimate of the importance of each feature.
featuresCol
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numFeatures Returns the number of features the model was trained on.
numNodes Return number of nodes of the decision tree.
params Returns all params ordered by name.
predictionCol
seed
supportedImpurities
toDebugString Full description of model.
varianceCol
weightCol


Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getCacheNodeIds()Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()Gets the value of checkpointInterval or its default value.

getFeaturesCol()Gets the value of featuresCol or its default value.

getImpurity()Gets the value of impurity or its default value.

New in version 1.4.0.

getLabelCol()Gets the value of labelCol or its default value.

getLeafCol()Gets the value of leafCol or its default value.

getMaxBins()Gets the value of maxBins or its default value.

getMaxDepth()Gets the value of maxDepth or its default value.


getMaxMemoryInMB()Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getSeed()Gets the value of seed or its default value.

getVarianceCol()Gets the value of varianceCol or its default value.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)Predict label for the given features.

New in version 3.0.0.

predictLeaf(value)Predict the indices of the leaves corresponding to the feature vector.

New in version 3.0.0.

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.


setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setLeafCol(value)Sets the value of leafCol.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setVarianceCol(value)Sets the value of varianceCol.

New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset
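For example, a minimal sketch of overriding an embedded param for a single call (assuming model is a fitted DecisionTreeRegressionModel and test0 is a DataFrame with a "features" column, as in the example above):

# The override applies only to this transform() call; the model itself keeps
# its configured predictionCol.
out = model.transform(test0, {model.predictionCol: "dt_prediction"})
out.select("dt_prediction").show()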

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

depth
Return depth of the decision tree.

New in version 1.5.0.

featureImportances
Estimate of the importance of each feature.

This generalizes the idea of “Gini” importance to other losses, following the explanation of Gini importance from “Random Forests” documentation by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.

This feature importance is calculated as follows:

• importance(feature j) = sum (over nodes which split on feature j) of the gain, where gain is scaled by the number of instances passing through the node

• Normalize importances for tree to sum to 1.

New in version 2.0.0.


Notes

Feature importance for single decision trees can have high variance due to correlated predictor variables. Consider using a RandomForestRegressor to determine feature importance instead.
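A minimal sketch of reading the importances back against feature names (feature_names is hypothetical; use the column names that were assembled into the "features" vector, and model is a fitted DecisionTreeRegressionModel):

feature_names = ["f0"]                    # hypothetical; one name per vector slot
importances = model.featureImportances    # e.g. SparseVector(1, {0: 1.0})
ranked = sorted(zip(feature_names, importances.toArray()), key=lambda kv: -kv[1])
for name, score in ranked:
    print(name, score)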

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')

maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')

maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')

maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')

minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')

minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')

minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')

numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.

New in version 2.1.0.

numNodes
Return number of nodes of the decision tree.

New in version 1.5.0.

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

supportedImpurities = ['variance']

toDebugString
Full description of model.

New in version 2.0.0.

varianceCol = Param(parent='undefined', name='varianceCol', doc='column name for the biased sample variance of prediction.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

GBTRegressor

class pyspark.ml.regression.GBTRegressor(*args, **kwargs)
Gradient-Boosted Trees (GBTs) learning algorithm for regression. It supports both continuous and categorical features.

New in version 1.4.0.


Examples

>>> from numpy import allclose
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> gbt = GBTRegressor(maxDepth=2, seed=42, leafCol="leafId")
>>> gbt.setMaxIter(5)
GBTRegressor...
>>> gbt.setMinWeightFractionPerNode(0.049)
GBTRegressor...
>>> gbt.getMaxIter()
5
>>> print(gbt.getImpurity())
variance
>>> print(gbt.getFeatureSubsetStrategy())
all
>>> model = gbt.fit(df)
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictLeaf(test0.head().features)
DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.leafId
DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> gbtr_path = temp_path + "gbtr"
>>> gbt.save(gbtr_path)
>>> gbt2 = GBTRegressor.load(gbtr_path)
>>> gbt2.getMaxDepth()
2
>>> model_path = temp_path + "gbtr_model"
>>> model.save(model_path)
>>> model2 = GBTRegressionModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.treeWeights == model2.treeWeights
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> model.trees
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
...     ["label", "features"])
>>> model.evaluateEachIteration(validation, "squared")
[0.0, 0.0, 0.0, 0.0, 0.0]
>>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
>>> gbt.getValidationIndicatorCol()
'validationIndicator'
>>> gbt.getValidationTol()
0.01

Methods

clear(param) Clears a param from the param map if it has beenexplicitly set.

copy([extra]) Creates a copy of this instance with the same uid andsome extra params.

explainParam(param) Explains a single param and returns its name, doc,and optional default value and user-supplied value ina string.

explainParams() Returns the documentation of all params with theiroptionally default values and user-supplied values.

extractParamMap([extra]) Extracts the embedded default param values anduser-supplied values, and then merges them with ex-tra values from input into a flat param map, wherethe latter value is used if there exist conflicts, i.e.,with ordering: default param values < user-suppliedvalues < extra.

fit(dataset[, params]) Fits a model to the input dataset with optional param-eters.

fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param mapin paramMaps.

getCacheNodeIds() Gets the value of cacheNodeIds or its default value.getCheckpointInterval() Gets the value of checkpointInterval or its default

value.getFeatureSubsetStrategy() Gets the value of featureSubsetStrategy or its default

value.getFeaturesCol() Gets the value of featuresCol or its default value.getImpurity() Gets the value of impurity or its default value.getLabelCol() Gets the value of labelCol or its default value.getLeafCol() Gets the value of leafCol or its default value.getLossType() Gets the value of lossType or its default value.getMaxBins() Gets the value of maxBins or its default value.getMaxDepth() Gets the value of maxDepth or its default value.getMaxIter() Gets the value of maxIter or its default value.getMaxMemoryInMB() Gets the value of maxMemoryInMB or its default

value.getMinInfoGain() Gets the value of minInfoGain or its default value.getMinInstancesPerNode() Gets the value of minInstancesPerNode or its default

value.getMinWeightFractionPerNode() Gets the value of minWeightFractionPerNode or its

default value.


getOrDefault(param) Gets the value of a param in the user-supplied param

map or its default value.getParam(paramName) Gets a param by its name.getPredictionCol() Gets the value of predictionCol or its default value.getSeed() Gets the value of seed or its default value.getStepSize() Gets the value of stepSize or its default value.getSubsamplingRate() Gets the value of subsamplingRate or its default

value.getValidationIndicatorCol() Gets the value of validationIndicatorCol or its default

value.getValidationTol() Gets the value of validationTol or its default value.getWeightCol() Gets the value of weightCol or its default value.hasDefault(param) Checks whether a param has a default value.hasParam(paramName) Tests whether this instance contains a param with a

given (string) name.isDefined(param) Checks whether a param is explicitly set by user or

has a default value.isSet(param) Checks whether a param is explicitly set by user.load(path) Reads an ML instance from the input path, a shortcut

of read().load(path).read() Returns an MLReader instance for this class.save(path) Save this ML instance to the given path, a shortcut of

‘write().save(path)’.set(param, value) Sets a parameter in the embedded param map.setCacheNodeIds(value) Sets the value of cacheNodeIds.setCheckpointInterval(value) Sets the value of checkpointInterval.setFeatureSubsetStrategy(value) Sets the value of featureSubsetStrategy .setFeaturesCol(value) Sets the value of featuresCol.setImpurity(value) Sets the value of impurity .setLabelCol(value) Sets the value of labelCol.setLeafCol(value) Sets the value of leafCol.setLossType(value) Sets the value of lossType.setMaxBins(value) Sets the value of maxBins.setMaxDepth(value) Sets the value of maxDepth.setMaxIter(value) Sets the value of maxIter.setMaxMemoryInMB(value) Sets the value of maxMemoryInMB.setMinInfoGain(value) Sets the value of minInfoGain.setMinInstancesPerNode(value) Sets the value of minInstancesPerNode.setMinWeightFractionPerNode(value) Sets the value of

minWeightFractionPerNode.setParams(self, \*[, featuresCol, labelCol, . . . ]) Sets params for Gradient Boosted Tree Regression.setPredictionCol(value) Sets the value of predictionCol.setSeed(value) Sets the value of seed.setStepSize(value) Sets the value of stepSize.setSubsamplingRate(value) Sets the value of subsamplingRate.setValidationIndicatorCol(value) Sets the value of validationIndicatorCol.setWeightCol(value) Sets the value of weightCol.write() Returns an MLWriter instance for this ML instance.


Attributes

cacheNodeIds
checkpointInterval
featureSubsetStrategy
featuresCol
impurity
labelCol
leafCol
lossType
maxBins
maxDepth
maxIter
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
params Returns all params ordered by name.
predictionCol
seed
stepSize
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
supportedLossTypes
validationIndicatorCol
validationTol
weightCol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.


extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getCacheNodeIds()Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()Gets the value of checkpointInterval or its default value.

getFeatureSubsetStrategy()Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getImpurity()Gets the value of impurity or its default value.

New in version 1.4.0.


getLabelCol()Gets the value of labelCol or its default value.

getLeafCol()Gets the value of leafCol or its default value.

getLossType()Gets the value of lossType or its default value.

New in version 1.4.0.

getMaxBins()Gets the value of maxBins or its default value.

getMaxDepth()Gets the value of maxDepth or its default value.

getMaxIter()Gets the value of maxIter or its default value.

getMaxMemoryInMB()Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getSeed()Gets the value of seed or its default value.

getStepSize()Gets the value of stepSize or its default value.

getSubsamplingRate()Gets the value of subsamplingRate or its default value.

New in version 1.4.0.

getValidationIndicatorCol()Gets the value of validationIndicatorCol or its default value.

getValidationTol()Gets the value of validationTol or its default value.

New in version 3.0.0.

getWeightCol()Gets the value of weightCol or its default value.


hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setCacheNodeIds(value)Sets the value of cacheNodeIds.

New in version 1.4.0.

setCheckpointInterval(value)Sets the value of checkpointInterval.

New in version 1.4.0.

setFeatureSubsetStrategy(value)Sets the value of featureSubsetStrategy .

New in version 2.4.0.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setImpurity(value)Sets the value of impurity .

New in version 1.4.0.

setLabelCol(value)Sets the value of labelCol.

New in version 3.0.0.

setLeafCol(value)Sets the value of leafCol.

setLossType(value)Sets the value of lossType.

New in version 1.4.0.

setMaxBins(value)Sets the value of maxBins.

New in version 1.4.0.


setMaxDepth(value)Sets the value of maxDepth.

New in version 1.4.0.

setMaxIter(value)Sets the value of maxIter.

New in version 1.4.0.

setMaxMemoryInMB(value)Sets the value of maxMemoryInMB.

New in version 1.4.0.

setMinInfoGain(value)Sets the value of minInfoGain.

New in version 1.4.0.

setMinInstancesPerNode(value)Sets the value of minInstancesPerNode.

New in version 1.4.0.

setMinWeightFractionPerNode(value)Sets the value of minWeightFractionPerNode.

New in version 3.0.0.

setParams(self, \*, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, impurity="variance", featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, weightCol=None)

Sets params for Gradient Boosted Tree Regression.

New in version 1.4.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setSeed(value)Sets the value of seed.

New in version 1.4.0.

setStepSize(value)Sets the value of stepSize.

New in version 1.4.0.

setSubsamplingRate(value)Sets the value of subsamplingRate.

New in version 1.4.0.

setValidationIndicatorCol(value)Sets the value of validationIndicatorCol.

New in version 3.0.0.


setWeightCol(value)Sets the value of weightCol.

New in version 3.0.0.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')

lossType = Param(parent='undefined', name='lossType', doc='Loss function which GBT tries to minimize (case-insensitive). Supported options: squared, absolute')

maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')

maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')

minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')

minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')

minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

stepSize = Param(parent='undefined', name='stepSize', doc='Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.')

subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')

supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']

supportedImpurities = ['variance']

supportedLossTypes = ['squared', 'absolute']

validationIndicatorCol = Param(parent='undefined', name='validationIndicatorCol', doc='name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.')

validationTol = Param(parent='undefined', name='validationTol', doc='Threshold for stopping early when fit with validation is used. If the error rate on the validation input changes by less than the validationTol, then learning will stop early (before `maxIter`). This parameter is ignored when fit without validation is used.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


GBTRegressionModel

class pyspark.ml.regression.GBTRegressionModel(java_model=None)
Model fitted by GBTRegressor.

New in version 1.4.0.

Methods

clear(param) Clears a param from the param map if it has beenexplicitly set.

copy([extra]) Creates a copy of this instance with the same uid andsome extra params.

evaluateEachIteration(dataset, loss) Method to compute error or loss for every iterationof gradient boosting.

explainParam(param) Explains a single param and returns its name, doc,and optional default value and user-supplied value ina string.

explainParams() Returns the documentation of all params with theiroptionally default values and user-supplied values.

extractParamMap([extra]) Extracts the embedded default param values anduser-supplied values, and then merges them with ex-tra values from input into a flat param map, wherethe latter value is used if there exist conflicts, i.e.,with ordering: default param values < user-suppliedvalues < extra.

getCacheNodeIds() Gets the value of cacheNodeIds or its default value.getCheckpointInterval() Gets the value of checkpointInterval or its default

value.getFeatureSubsetStrategy() Gets the value of featureSubsetStrategy or its default

value.getFeaturesCol() Gets the value of featuresCol or its default value.getImpurity() Gets the value of impurity or its default value.getLabelCol() Gets the value of labelCol or its default value.getLeafCol() Gets the value of leafCol or its default value.getLossType() Gets the value of lossType or its default value.getMaxBins() Gets the value of maxBins or its default value.getMaxDepth() Gets the value of maxDepth or its default value.getMaxIter() Gets the value of maxIter or its default value.getMaxMemoryInMB() Gets the value of maxMemoryInMB or its default

value.getMinInfoGain() Gets the value of minInfoGain or its default value.getMinInstancesPerNode() Gets the value of minInstancesPerNode or its default

value.getMinWeightFractionPerNode() Gets the value of minWeightFractionPerNode or its

default value.getOrDefault(param) Gets the value of a param in the user-supplied param

map or its default value.getParam(paramName) Gets a param by its name.getPredictionCol() Gets the value of predictionCol or its default value.getSeed() Gets the value of seed or its default value.


getStepSize() Gets the value of stepSize or its default value.getSubsamplingRate() Gets the value of subsamplingRate or its default

value.getValidationIndicatorCol() Gets the value of validationIndicatorCol or its default

value.getValidationTol() Gets the value of validationTol or its default value.getWeightCol() Gets the value of weightCol or its default value.hasDefault(param) Checks whether a param has a default value.hasParam(paramName) Tests whether this instance contains a param with a

given (string) name.isDefined(param) Checks whether a param is explicitly set by user or

has a default value.isSet(param) Checks whether a param is explicitly set by user.load(path) Reads an ML instance from the input path, a shortcut

of read().load(path).predict(value) Predict label for the given features.predictLeaf(value) Predict the indices of the leaves corresponding to the

feature vector.read() Returns an MLReader instance for this class.save(path) Save this ML instance to the given path, a shortcut of

‘write().save(path)’.set(param, value) Sets a parameter in the embedded param map.setFeaturesCol(value) Sets the value of featuresCol.setLeafCol(value) Sets the value of leafCol.setPredictionCol(value) Sets the value of predictionCol.transform(dataset[, params]) Transforms the input dataset with optional parame-

ters.write() Returns an MLWriter instance for this ML instance.

Attributes

cacheNodeIds
checkpointInterval
featureImportances Estimate of the importance of each feature.
featureSubsetStrategy
featuresCol
getNumTrees Number of trees in ensemble.
impurity
labelCol
leafCol
lossType
maxBins
maxDepth
maxIter
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numFeatures Returns the number of features the model was trained on.
params Returns all params ordered by name.
predictionCol
seed
stepSize
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
supportedLossTypes
toDebugString Full description of model.
totalNumNodes Total number of nodes, summed over all trees in the ensemble.
treeWeights Return the weights for each tree.
trees Trees in this ensemble.
validationIndicatorCol
validationTol
weightCol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluateEachIteration(dataset, loss)Method to compute error or loss for every iteration of gradient boosting.

New in version 2.4.0.

Parameters

dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on, where dataset is an instance of pyspark.sql.DataFrame

loss [str] The loss function used to compute error. Supported options: squared, absolute
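A minimal sketch of using the per-iteration losses, for instance to spot the best boosting iteration (assuming model is a fitted GBTRegressionModel and validation is a DataFrame with "label" and "features" columns, as in the example above):

losses = model.evaluateEachIteration(validation, "squared")  # one value per iteration
best_iter = min(range(len(losses)), key=lambda i: losses[i])
print(best_iter, losses[best_iter])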

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getCacheNodeIds()Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()Gets the value of checkpointInterval or its default value.

getFeatureSubsetStrategy()Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getImpurity()Gets the value of impurity or its default value.

New in version 1.4.0.

getLabelCol()Gets the value of labelCol or its default value.

getLeafCol()Gets the value of leafCol or its default value.

getLossType()Gets the value of lossType or its default value.

New in version 1.4.0.

getMaxBins()Gets the value of maxBins or its default value.

getMaxDepth()Gets the value of maxDepth or its default value.

getMaxIter()Gets the value of maxIter or its default value.

getMaxMemoryInMB()Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.


getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getSeed()Gets the value of seed or its default value.

getStepSize()Gets the value of stepSize or its default value.

getSubsamplingRate()Gets the value of subsamplingRate or its default value.

New in version 1.4.0.

getValidationIndicatorCol()Gets the value of validationIndicatorCol or its default value.

getValidationTol()Gets the value of validationTol or its default value.

New in version 3.0.0.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)Predict label for the given features.

New in version 3.0.0.

predictLeaf(value)Predict the indices of the leaves corresponding to the feature vector.

New in version 3.0.0.

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeaturesCol(value)Sets the value of featuresCol.


New in version 3.0.0.

setLeafCol(value)Sets the value of leafCol.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

featureImportances
Estimate of the importance of each feature.

Each feature’s importance is the average of its importance across all trees in the ensemble. The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. (Hastie, Tibshirani, Friedman. “The Elements of Statistical Learning, 2nd Edition.” 2001.) and follows the implementation from scikit-learn.

New in version 2.0.0.

Examples

DecisionTreeRegressionModel.featureImportances

featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

getNumTrees Number of trees in ensemble.

New in version 2.0.0.

impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')

lossType = Param(parent='undefined', name='lossType', doc='Loss function which GBT tries to minimize (case-insensitive). Supported options: squared, absolute')


maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')

maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')

minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')

minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')

minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')

numFeatures Returns the number of features the model was trained on. If unknown, returns -1.

New in version 2.1.0.

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

stepSize = Param(parent='undefined', name='stepSize', doc='Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.')

subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')

supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']

supportedImpurities = ['variance']

supportedLossTypes = ['squared', 'absolute']

toDebugString Full description of model.

New in version 2.0.0.

totalNumNodes Total number of nodes, summed over all trees in the ensemble.

New in version 2.0.0.

treeWeights Return the weights for each tree.

New in version 1.5.0.

trees Trees in this ensemble. Warning: These have null parent Estimators.

New in version 2.0.0.

validationIndicatorCol = Param(parent='undefined', name='validationIndicatorCol', doc='name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.')

validationTol = Param(parent='undefined', name='validationTol', doc='Threshold for stopping early when fit with validation is used. If the error rate on the validation input changes by less than the validationTol, then learning will stop early (before `maxIter`). This parameter is ignored when fit without validation is used.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


GeneralizedLinearRegression

class pyspark.ml.regression.GeneralizedLinearRegression(*args, **kwargs)
Generalized Linear Regression.

Fit a Generalized Linear Model specified by giving a symbolic description of the linear predictor (link function) and a description of the error distribution (family). It supports "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for each family are listed below. The first link function of each family is the default one.

• "gaussian" -> "identity", "log", "inverse"

• "binomial" -> "logit", "probit", "cloglog"

• "poisson" -> "log", "identity", "sqrt"

• "gamma" -> "inverse", "identity", "log"

• "tweedie" -> power link function specified through "linkPower". The default link power in the tweedie family is 1 - variancePower.

New in version 2.0.0.

Notes

For more information see Wikipedia page on GLM
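
As a hedged illustration of the family/link mapping above (not taken from the Spark sources), a Tweedie GLM specifies its link through linkPower instead of link; the values below are arbitrary example settings:

from pyspark.ml.regression import GeneralizedLinearRegression

# variancePower picks a distribution between Poisson (1.0) and gamma (2.0);
# linkPower=0.0 corresponds to the log link, and leaving it unset would use
# the default 1 - variancePower mentioned above.
glr_tweedie = GeneralizedLinearRegression(
    family="tweedie", variancePower=1.5, linkPower=0.0, maxIter=10)
# glr_tweedie.fit(df)  # `df` would be a DataFrame with "label"/"features" columns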

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(0.0, 0.0)),
...     (1.0, Vectors.dense(1.0, 2.0)),
...     (2.0, Vectors.dense(0.0, 0.0)),
...     (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
>>> glr.setRegParam(0.1)
GeneralizedLinearRegression...
>>> glr.getRegParam()
0.1
>>> glr.clear(glr.regParam)
>>> glr.setMaxIter(10)
GeneralizedLinearRegression...
>>> glr.getMaxIter()
10
>>> glr.clear(glr.maxIter)
>>> model = glr.fit(df)
>>> model.setFeaturesCol("features")
GeneralizedLinearRegressionModel...
>>> model.getMaxIter()
25
>>> model.getAggregationDepth()
2
>>> transformed = model.transform(df)
>>> abs(transformed.head().prediction - 1.5) < 0.001
True
>>> abs(transformed.head().p - 1.5) < 0.001
True
>>> model.coefficients
DenseVector([1.5..., -1.0...])
>>> model.numFeatures
2
>>> abs(model.intercept - 1.5) < 0.001
True
>>> glr_path = temp_path + "/glr"
>>> glr.save(glr_path)
>>> glr2 = GeneralizedLinearRegression.load(glr_path)
>>> glr.getFamily() == glr2.getFamily()
True
>>> model_path = temp_path + "/glr_model"
>>> model.save(model_path)
>>> model2 = GeneralizedLinearRegressionModel.load(model_path)
>>> model.intercept == model2.intercept
True
>>> model.coefficients[0] == model2.coefficients[0]
True
>>> model.transform(df).take(1) == model2.transform(df).take(1)
True

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getAggregationDepth() Gets the value of aggregationDepth or its default value.
getFamily() Gets the value of family or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFitIntercept() Gets the value of fitIntercept or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getLink() Gets the value of link or its default value.
getLinkPower() Gets the value of linkPower or its default value.
getLinkPredictionCol() Gets the value of linkPredictionCol or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOffsetCol() Gets the value of offsetCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getRegParam() Gets the value of regParam or its default value.
getSolver() Gets the value of solver or its default value.
getTol() Gets the value of tol or its default value.
getVariancePower() Gets the value of variancePower or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setAggregationDepth(value) Sets the value of aggregationDepth.
setFamily(value) Sets the value of family.
setFeaturesCol(value) Sets the value of featuresCol.
setFitIntercept(value) Sets the value of fitIntercept.
setLabelCol(value) Sets the value of labelCol.
setLink(value) Sets the value of link.
setLinkPower(value) Sets the value of linkPower.
setLinkPredictionCol(value) Sets the value of linkPredictionCol.
setMaxIter(value) Sets the value of maxIter.
setOffsetCol(value) Sets the value of offsetCol.
setParams(self, \*[, labelCol, featuresCol, ...]) Sets params for generalized linear regression.
setPredictionCol(value) Sets the value of predictionCol.
setRegParam(value) Sets the value of regParam.
setSolver(value) Sets the value of solver.
setTol(value) Sets the value of tol.
setVariancePower(value) Sets the value of variancePower.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.


Attributes

aggregationDepth
family
featuresCol
fitIntercept
labelCol
link
linkPower
linkPredictionCol
maxIter
offsetCol
params Returns all params ordered by name.
predictionCol
regParam
solver
tol
variancePower
weightCol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map


fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
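
A small sketch of driving fitMultiple with param maps built by pyspark.ml.tuning.ParamGridBuilder; it assumes the glr estimator and DataFrame df from the example above are in scope:

from pyspark.ml.tuning import ParamGridBuilder

param_maps = (ParamGridBuilder()
              .addGrid(glr.regParam, [0.0, 0.1])
              .addGrid(glr.maxIter, [10, 25])
              .build())
# The iterator yields (index, model) pairs; index values may not be sequential.
for index, model in glr.fitMultiple(df, param_maps):
    print(index, model.intercept)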

getAggregationDepth()Gets the value of aggregationDepth or its default value.

getFamily()Gets the value of family or its default value.

New in version 2.0.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getFitIntercept()Gets the value of fitIntercept or its default value.

getLabelCol()Gets the value of labelCol or its default value.

getLink()Gets the value of link or its default value.

New in version 2.0.0.

getLinkPower()Gets the value of linkPower or its default value.

New in version 2.2.0.

getLinkPredictionCol()Gets the value of linkPredictionCol or its default value.

New in version 2.0.0.


getMaxIter()Gets the value of maxIter or its default value.

getOffsetCol()Gets the value of offsetCol or its default value.

New in version 2.3.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getRegParam()Gets the value of regParam or its default value.

getSolver()Gets the value of solver or its default value.

getTol()Gets the value of tol or its default value.

getVariancePower()Gets the value of variancePower or its default value.

New in version 2.2.0.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setAggregationDepth(value)Sets the value of aggregationDepth.

New in version 3.0.0.


setFamily(value)Sets the value of family .

New in version 2.0.0.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setFitIntercept(value)Sets the value of fitIntercept.

New in version 2.0.0.

setLabelCol(value)Sets the value of labelCol.

New in version 3.0.0.

setLink(value)Sets the value of link.

New in version 2.0.0.

setLinkPower(value)Sets the value of linkPower.

New in version 2.2.0.

setLinkPredictionCol(value)Sets the value of linkPredictionCol.

New in version 2.0.0.

setMaxIter(value)Sets the value of maxIter.

New in version 2.0.0.

setOffsetCol(value)Sets the value of offsetCol.

New in version 2.3.0.

setParams(self, \*, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2)

Sets params for generalized linear regression.

New in version 2.0.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setRegParam(value)Sets the value of regParam.

New in version 2.0.0.

setSolver(value)Sets the value of solver.

New in version 2.0.0.


setTol(value)Sets the value of tol.

New in version 2.0.0.

setVariancePower(value)Sets the value of variancePower.

New in version 2.2.0.

setWeightCol(value)Sets the value of weightCol.

New in version 2.0.0.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')

family = Param(parent='undefined', name='family', doc='The name of family which is a description of the error distribution to be used in the model. Supported options: gaussian (default), binomial, poisson, gamma and tweedie.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

link = Param(parent='undefined', name='link', doc='The name of link function which provides the relationship between the linear predictor and the mean of the distribution function. Supported options: identity, log, inverse, logit, probit, cloglog and sqrt.')

linkPower = Param(parent='undefined', name='linkPower', doc='The index in the power link function. Only applicable to the Tweedie family.')

linkPredictionCol = Param(parent='undefined', name='linkPredictionCol', doc='link prediction (linear predictor) column name')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

offsetCol = Param(parent='undefined', name='offsetCol', doc='The offset column name. If this is not set or empty, we treat all instance offsets as 0.0')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')

solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: irls.')

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

variancePower = Param(parent='undefined', name='variancePower', doc='The power in the variance function of the Tweedie distribution which characterizes the relationship between the variance and mean of the distribution. Only applicable for the Tweedie family. Supported values: 0 and [1, Inf).')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


GeneralizedLinearRegressionModel

class pyspark.ml.regression.GeneralizedLinearRegressionModel(java_model=None)
Model fitted by GeneralizedLinearRegression.

New in version 2.0.0.

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset) Evaluates the model on a test dataset.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getAggregationDepth() Gets the value of aggregationDepth or its default value.
getFamily() Gets the value of family or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFitIntercept() Gets the value of fitIntercept or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getLink() Gets the value of link or its default value.
getLinkPower() Gets the value of linkPower or its default value.
getLinkPredictionCol() Gets the value of linkPredictionCol or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOffsetCol() Gets the value of offsetCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getRegParam() Gets the value of regParam or its default value.
getSolver() Gets the value of solver or its default value.
getTol() Gets the value of tol or its default value.
getVariancePower() Gets the value of variancePower or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) Predict label for the given features.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setLinkPredictionCol(value) Sets the value of linkPredictionCol.
setPredictionCol(value) Sets the value of predictionCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

aggregationDepth
coefficients Model coefficients.
family
featuresCol
fitIntercept
hasSummary Indicates whether a training summary exists for this model instance.
intercept Model intercept.
labelCol
link
linkPower
linkPredictionCol
maxIter
numFeatures Returns the number of features the model was trained on.
offsetCol
params Returns all params ordered by name.
predictionCol
regParam
solver
summary Gets summary (e.g. residuals, deviance, pValues) of model on training set.
tol
variancePower
weightCol


Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluate(dataset) Evaluates the model on a test dataset.

New in version 2.0.0.

Parameters

dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on, where dataset is an instance of pyspark.sql.DataFrame
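
A minimal sketch of evaluate(), assuming model and df from the GeneralizedLinearRegression example above (any DataFrame with the same schema would do):

summary = model.evaluate(df)  # returns a GeneralizedLinearRegressionSummary
print(summary.deviance, summary.aic, summary.numInstances)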

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getAggregationDepth()Gets the value of aggregationDepth or its default value.

getFamily()Gets the value of family or its default value.

New in version 2.0.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getFitIntercept()Gets the value of fitIntercept or its default value.

getLabelCol()Gets the value of labelCol or its default value.


getLink()Gets the value of link or its default value.

New in version 2.0.0.

getLinkPower()Gets the value of linkPower or its default value.

New in version 2.2.0.

getLinkPredictionCol()Gets the value of linkPredictionCol or its default value.

New in version 2.0.0.

getMaxIter()Gets the value of maxIter or its default value.

getOffsetCol()Gets the value of offsetCol or its default value.

New in version 2.3.0.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getRegParam()Gets the value of regParam or its default value.

getSolver()Gets the value of solver or its default value.

getTol()Gets the value of tol or its default value.

getVariancePower()Gets the value of variancePower or its default value.

New in version 2.2.0.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).


predict(value)Predict label for the given features.

New in version 3.0.0.

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setLinkPredictionCol(value)Sets the value of linkPredictionCol.

New in version 3.0.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')

coefficients Model coefficients.

New in version 2.0.0.

family = Param(parent='undefined', name='family', doc='The name of family which is a description of the error distribution to be used in the model. Supported options: gaussian (default), binomial, poisson, gamma and tweedie.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')

hasSummary Indicates whether a training summary exists for this model instance.

New in version 2.1.0.


intercept Model intercept.

New in version 2.0.0.

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

link = Param(parent='undefined', name='link', doc='The name of link function which provides the relationship between the linear predictor and the mean of the distribution function. Supported options: identity, log, inverse, logit, probit, cloglog and sqrt.')

linkPower = Param(parent='undefined', name='linkPower', doc='The index in the power link function. Only applicable to the Tweedie family.')

linkPredictionCol = Param(parent='undefined', name='linkPredictionCol', doc='link prediction (linear predictor) column name')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

numFeatures Returns the number of features the model was trained on. If unknown, returns -1.

New in version 2.1.0.

offsetCol = Param(parent='undefined', name='offsetCol', doc='The offset column name. If this is not set or empty, we treat all instance offsets as 0.0')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')

solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: irls.')

summary Gets summary (e.g. residuals, deviance, pValues) of model on training set. An exception is thrown if trainingSummary is None.

New in version 2.0.0.

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

variancePower = Param(parent='undefined', name='variancePower', doc='The power in the variance function of the Tweedie distribution which characterizes the relationship between the variance and mean of the distribution. Only applicable for the Tweedie family. Supported values: 0 and [1, Inf).')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

GeneralizedLinearRegressionSummary

class pyspark.ml.regression.GeneralizedLinearRegressionSummary(java_obj=None)
Generalized linear regression results evaluated on a dataset.

New in version 2.0.0.

Methods

residuals([residualsType]) Get the residuals of the fitted model by type.


Attributes

aic Akaike's "An Information Criterion" (AIC) for the fitted model.
degreesOfFreedom Degrees of freedom.
deviance The deviance for the fitted model.
dispersion The dispersion of the fitted model.
nullDeviance The deviance for the null model.
numInstances Number of instances in DataFrame predictions.
predictionCol Field in predictions which gives the predicted value of each instance.
predictions Predictions output by the model's transform method.
rank The numeric rank of the fitted linear model.
residualDegreeOfFreedom The residual degrees of freedom.
residualDegreeOfFreedomNull The residual degrees of freedom for the null model.

Methods Documentation

residuals(residualsType='deviance') Get the residuals of the fitted model by type.

New in version 2.0.0.

Parameters

residualsType [str, optional] The type of residuals which should be returned. Supported options: deviance (default), pearson, working, and response.
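
A short sketch of the residual types, assuming summary is a GeneralizedLinearRegressionSummary obtained from model.evaluate(...) or model.summary:

summary.residuals().show()            # deviance residuals (the default)
summary.residuals("pearson").show()   # Pearson residuals
summary.residuals("working").show()   # working residuals
summary.residuals("response").show()  # response residuals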

Attributes Documentation

aic Akaike's "An Information Criterion" (AIC) for the fitted model.

New in version 2.0.0.

degreesOfFreedom Degrees of freedom.

New in version 2.0.0.

deviance The deviance for the fitted model.

New in version 2.0.0.

dispersion The dispersion of the fitted model. It is taken as 1.0 for the "binomial" and "poisson" families, and otherwise estimated by the residual Pearson's Chi-Squared statistic (which is defined as the sum of the squares of the Pearson residuals) divided by the residual degrees of freedom.

New in version 2.0.0.

nullDeviance The deviance for the null model.

New in version 2.0.0.


numInstances Number of instances in DataFrame predictions.

New in version 2.2.0.

predictionCol Field in predictions which gives the predicted value of each instance. This is set to a new column name if the original model's predictionCol is not set.

New in version 2.0.0.

predictions Predictions output by the model's transform method.

New in version 2.0.0.

rank The numeric rank of the fitted linear model.

New in version 2.0.0.

residualDegreeOfFreedom The residual degrees of freedom.

New in version 2.0.0.

residualDegreeOfFreedomNull The residual degrees of freedom for the null model.

New in version 2.0.0.

GeneralizedLinearRegressionTrainingSummary

class pyspark.ml.regression.GeneralizedLinearRegressionTrainingSummary(java_obj=None)
Generalized linear regression training results.

New in version 2.0.0.

Methods

residuals([residualsType]) Get the residuals of the fitted model by type.

Attributes

aic Akaike's "An Information Criterion" (AIC) for the fitted model.
coefficientStandardErrors Standard error of estimated coefficients and intercept.
degreesOfFreedom Degrees of freedom.
deviance The deviance for the fitted model.
dispersion The dispersion of the fitted model.
nullDeviance The deviance for the null model.
numInstances Number of instances in DataFrame predictions.
numIterations Number of training iterations.
pValues Two-sided p-value of estimated coefficients and intercept.
predictionCol Field in predictions which gives the predicted value of each instance.
predictions Predictions output by the model's transform method.
rank The numeric rank of the fitted linear model.
residualDegreeOfFreedom The residual degrees of freedom.
residualDegreeOfFreedomNull The residual degrees of freedom for the null model.
solver The numeric solver used for training.
tValues T-statistic of estimated coefficients and intercept.

Methods Documentation

residuals(residualsType='deviance') Get the residuals of the fitted model by type.

New in version 2.0.0.

Parameters

residualsType [str, optional] The type of residuals which should be returned. Supported options: deviance (default), pearson, working, and response.

Attributes Documentation

aic Akaike's "An Information Criterion" (AIC) for the fitted model.

New in version 2.0.0.

coefficientStandardErrors Standard error of estimated coefficients and intercept.

If GeneralizedLinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.

New in version 2.0.0.

degreesOfFreedom Degrees of freedom.

New in version 2.0.0.

deviance The deviance for the fitted model.

New in version 2.0.0.

dispersion The dispersion of the fitted model. It is taken as 1.0 for the "binomial" and "poisson" families, and otherwise estimated by the residual Pearson's Chi-Squared statistic (which is defined as the sum of the squares of the Pearson residuals) divided by the residual degrees of freedom.

New in version 2.0.0.

nullDeviance The deviance for the null model.

New in version 2.0.0.


numInstances Number of instances in DataFrame predictions.

New in version 2.2.0.

numIterations Number of training iterations.

New in version 2.0.0.

pValues Two-sided p-value of estimated coefficients and intercept.

If GeneralizedLinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.

New in version 2.0.0.

predictionCol Field in predictions which gives the predicted value of each instance. This is set to a new column name if the original model's predictionCol is not set.

New in version 2.0.0.

predictions Predictions output by the model's transform method.

New in version 2.0.0.

rank The numeric rank of the fitted linear model.

New in version 2.0.0.

residualDegreeOfFreedom The residual degrees of freedom.

New in version 2.0.0.

residualDegreeOfFreedomNull The residual degrees of freedom for the null model.

New in version 2.0.0.

solver The numeric solver used for training.

New in version 2.0.0.

tValues T-statistic of estimated coefficients and intercept.

If GeneralizedLinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.

New in version 2.0.0.
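
A hedged sketch of reading the coefficient statistics above, assuming model is a GeneralizedLinearRegressionModel fitted with fitIntercept=True on two features (the feature names are hypothetical placeholders):

ts = model.summary  # GeneralizedLinearRegressionTrainingSummary
names = ["x0", "x1", "intercept"]  # hypothetical labels; the intercept comes last
for name, se, t, p in zip(names, ts.coefficientStandardErrors, ts.tValues, ts.pValues):
    print(name, se, t, p)
print("iterations:", ts.numIterations, "solver:", ts.solver)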


IsotonicRegression

class pyspark.ml.regression.IsotonicRegression(*args, **kwargs)
Isotonic regression, currently implemented using a parallelized pool adjacent violators algorithm. Only the univariate (single feature) algorithm is supported.

New in version 1.6.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> ir = IsotonicRegression()
>>> model = ir.fit(df)
>>> model.setFeaturesCol("features")
IsotonicRegressionModel...
>>> model.numFeatures
1
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
>>> model.predict(test0.head().features[model.getFeatureIndex()])
0.0
>>> model.boundaries
DenseVector([0.0, 1.0])
>>> ir_path = temp_path + "/ir"
>>> ir.save(ir_path)
>>> ir2 = IsotonicRegression.load(ir_path)
>>> ir2.getIsotonic()
True
>>> model_path = temp_path + "/ir_model"
>>> model.save(model_path)
>>> model2 = IsotonicRegressionModel.load(model_path)
>>> model.boundaries == model2.boundaries
True
>>> model.predictions == model2.predictions
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getFeatureIndex() Gets the value of featureIndex or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getIsotonic() Gets the value of isotonic or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeatureIndex(value) Sets the value of featureIndex.
setFeaturesCol(value) Sets the value of featuresCol.
setIsotonic(value) Sets the value of isotonic.
setLabelCol(value) Sets the value of labelCol.
setParams(*args, **kwargs) setParams(self, \*, featuresCol="features", labelCol="label", predictionCol="prediction", weightCol=None, isotonic=True, featureIndex=0): Set the params for IsotonicRegression.
setPredictionCol(value) Sets the value of predictionCol.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.


Attributes

featureIndex
featuresCol
isotonic
labelCol
params Returns all params ordered by name.
predictionCol
weightCol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embeddedparams. If a list/tuple of param maps is given, this calls fit on each param map and returnsa list of models.


Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for eachparam map. Each call to next(modelIterator) will return (index, model) where model wasfit using paramMaps[index]. index values may not be sequential.

getFeatureIndex()Gets the value of featureIndex or its default value.

getFeaturesCol()Gets the value of featuresCol or its default value.

getIsotonic()Gets the value of isotonic or its default value.

getLabelCol()Gets the value of labelCol or its default value.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.


save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeatureIndex(value)Sets the value of featureIndex.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 1.6.0.

setIsotonic(value)Sets the value of isotonic.

setLabelCol(value)Sets the value of labelCol.

New in version 1.6.0.

setParams(*args, **kwargs) setParams(self, \*, featuresCol="features", labelCol="label", predictionCol="prediction", weightCol=None, isotonic=True, featureIndex=0): Set the params for IsotonicRegression.

setPredictionCol(value)Sets the value of predictionCol.

New in version 1.6.0.

setWeightCol(value)Sets the value of weightCol.

New in version 1.6.0.

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

featureIndex = Param(parent='undefined', name='featureIndex', doc='The index of the feature if featuresCol is a vector column, no effect otherwise.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

isotonic = Param(parent='undefined', name='isotonic', doc='whether the output sequence should be isotonic/increasing (true) orantitonic/decreasing (false).')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


IsotonicRegressionModel

class pyspark.ml.regression.IsotonicRegressionModel(java_model=None)
Model fitted by IsotonicRegression.

New in version 1.6.0.

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFeatureIndex() Gets the value of featureIndex or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getIsotonic() Gets the value of isotonic or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) Predict label for the given features.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeatureIndex(value) Sets the value of featureIndex.
setFeaturesCol(value) Sets the value of featuresCol.
setPredictionCol(value) Sets the value of predictionCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

boundaries Boundaries in increasing order for which predictions are known.
featureIndex
featuresCol
isotonic
labelCol
numFeatures Returns the number of features the model was trained on.
params Returns all params ordered by name.
predictionCol
predictions Predictions associated with the boundaries at the same index, monotone because of isotonic regression.
weightCol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)Explains a single param and returns its name, doc, and optional default value and user-supplied value in astring.

explainParams()Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map


getFeatureIndex()Gets the value of featureIndex or its default value.

getFeaturesCol()Gets the value of featuresCol or its default value.

getIsotonic()Gets the value of isotonic or its default value.

getLabelCol()Gets the value of labelCol or its default value.

getOrDefault(param)Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither isset.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)Predict label for the given features.

New in version 3.0.0.

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeatureIndex(value)Sets the value of featureIndex.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)Sets the value of predictionCol.


New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.

Attributes Documentation

boundaries Boundaries in increasing order for which predictions are known.

New in version 1.6.0.

featureIndex = Param(parent='undefined', name='featureIndex', doc='The index of the feature if featuresCol is a vector column, no effect otherwise.')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

isotonic = Param(parent='undefined', name='isotonic', doc='whether the output sequence should be isotonic/increasing (true) orantitonic/decreasing (false).')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

numFeatures Returns the number of features the model was trained on. If unknown, returns -1.

New in version 3.0.0.

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

predictions Predictions associated with the boundaries at the same index, monotone because of isotonic regression.

New in version 1.6.0.
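
A small sketch (assuming the model fitted in the IsotonicRegression example above, with boundaries [0.0, 1.0]) of how boundaries and predictions pair up:

print(model.boundaries)   # e.g. DenseVector([0.0, 1.0])
print(model.predictions)  # the prediction at each boundary, same length and order
# Between two boundaries the model interpolates linearly between the
# corresponding predictions; outside the range it returns the nearest one.
print(model.predict(0.5))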

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


LinearRegression

class pyspark.ml.regression.LinearRegression(*args, **kwargs)
Linear regression.

The learning objective is to minimize the specified loss function, with regularization. This supports two kinds of loss:

• squaredError (a.k.a. squared loss)

• huber (a hybrid of squared error for relatively small errors and absolute error for relatively large ones; the scale parameter is estimated from the training data)

This supports multiple types of regularization:

• none (a.k.a. ordinary least squares)

• L2 (ridge regression)

• L1 (Lasso)

• L2 + L1 (elastic net)

New in version 1.4.0.

Notes

Fitting with huber loss only supports none and L2 regularization.
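
As a hedged sketch (not from the original docs), the loss and regularization choices above map onto the loss, regParam and elasticNetParam params; df would be a DataFrame with "label"/"features" columns as in the example below:

from pyspark.ml.regression import LinearRegression

ols     = LinearRegression(regParam=0.0)                       # ordinary least squares
ridge   = LinearRegression(regParam=0.1, elasticNetParam=0.0)  # L2 only
lasso   = LinearRegression(regParam=0.1, elasticNetParam=1.0)  # L1 only
elastic = LinearRegression(regParam=0.1, elasticNetParam=0.5)  # L2 + L1 mix
huber   = LinearRegression(loss="huber", regParam=0.1)         # none or L2 only
# model = ridge.fit(df)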

Examples

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, 2.0, Vectors.dense(1.0)),
...     (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
>>> lr = LinearRegression(regParam=0.0, solver="normal", weightCol="weight")
>>> lr.setMaxIter(5)
LinearRegression...
>>> lr.getMaxIter()
5
>>> lr.setRegParam(0.1)
LinearRegression...
>>> lr.getRegParam()
0.1
>>> lr.setRegParam(0.0)
LinearRegression...
>>> model = lr.fit(df)
>>> model.setFeaturesCol("features")
LinearRegressionModel...
>>> model.setPredictionCol("newPrediction")
LinearRegressionModel...
>>> model.getMaxIter()
5
>>> model.getMaxBlockSizeInMB()
0.0
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> abs(model.predict(test0.head().features) - (-1.0)) < 0.001
True
>>> abs(model.transform(test0).head().newPrediction - (-1.0)) < 0.001
True
>>> abs(model.coefficients[0] - 1.0) < 0.001
True
>>> abs(model.intercept - 0.0) < 0.001
True
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> abs(model.transform(test1).head().newPrediction - 1.0) < 0.001
True
>>> lr.setParams("vector")
Traceback (most recent call last):
    ...
TypeError: Method setParams forces keyword arguments.
>>> lr_path = temp_path + "/lr"
>>> lr.save(lr_path)
>>> lr2 = LinearRegression.load(lr_path)
>>> lr2.getMaxIter()
5
>>> model_path = temp_path + "/lr_model"
>>> model.save(model_path)
>>> model2 = LinearRegressionModel.load(model_path)
>>> model.coefficients[0] == model2.coefficients[0]
True
>>> model.intercept == model2.intercept
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> model.numFeatures
1
>>> model.write().format("pmml").save(model_path + "_2")

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getAggregationDepth() Gets the value of aggregationDepth or its default value.
getElasticNetParam() Gets the value of elasticNetParam or its default value.
getEpsilon() Gets the value of epsilon or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFitIntercept() Gets the value of fitIntercept or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getLoss() Gets the value of loss or its default value.
getMaxBlockSizeInMB() Gets the value of maxBlockSizeInMB or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getRegParam() Gets the value of regParam or its default value.
getSolver() Gets the value of solver or its default value.
getStandardization() Gets the value of standardization or its default value.
getTol() Gets the value of tol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setAggregationDepth(value) Sets the value of aggregationDepth.
setElasticNetParam(value) Sets the value of elasticNetParam.
setEpsilon(value) Sets the value of epsilon.
setFeaturesCol(value) Sets the value of featuresCol.
setFitIntercept(value) Sets the value of fitIntercept.
setLabelCol(value) Sets the value of labelCol.
setLoss(value) Sets the value of loss.
setMaxBlockSizeInMB(value) Sets the value of maxBlockSizeInMB.
setMaxIter(value) Sets the value of maxIter.
setParams(self, \*[, featuresCol, labelCol, ...]) Sets params for linear regression.
setPredictionCol(value) Sets the value of predictionCol.
setRegParam(value) Sets the value of regParam.
setSolver(value) Sets the value of solver.
setStandardization(value) Sets the value of standardization.
setTol(value) Sets the value of tol.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.

Attributes

aggregationDepth
elasticNetParam
epsilon
featuresCol
fitIntercept
labelCol
loss
maxBlockSizeInMB
maxIter
params Returns all params ordered by name.
predictionCol
regParam
solver
standardization
tol
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values


Returns

dict merged param map
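For example, a minimal sketch of the merge ordering (the estimator and values below are illustrative only): a value supplied through extra overrides the one set on the instance.

>>> lr = LinearRegression(maxIter=5)
>>> pm = lr.extractParamMap({lr.maxIter: 20})
>>> pm[lr.maxIter]   # extra wins over the user-supplied value
20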

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)
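A small sketch of the list form (reusing df and lr from the example above; the param maps are illustrative), which returns one fitted model per map:

>>> param_maps = [{lr.regParam: 0.0}, {lr.regParam: 0.1}]
>>> models = lr.fit(df, param_maps)
>>> len(models)
2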

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
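A hedged sketch of consuming the iterator (df and lr as in the example above; the param maps are illustrative). Because the (index, model) pairs may arrive out of order, results are placed by index:

>>> param_maps = [{lr.regParam: 0.0}, {lr.regParam: 0.1}]
>>> fitted = [None] * len(param_maps)
>>> for index, fitted_model in lr.fitMultiple(df, param_maps):
...     fitted[index] = fitted_model
>>> len([m for m in fitted if m is not None])
2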

getAggregationDepth()
Gets the value of aggregationDepth or its default value.

getElasticNetParam()
Gets the value of elasticNetParam or its default value.

getEpsilon()
Gets the value of epsilon or its default value.

New in version 2.3.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFitIntercept()
Gets the value of fitIntercept or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getLoss()
Gets the value of loss or its default value.

getMaxBlockSizeInMB()
Gets the value of maxBlockSizeInMB or its default value.

getMaxIter()
Gets the value of maxIter or its default value.


getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getRegParam()
Gets the value of regParam or its default value.

getSolver()
Gets the value of solver or its default value.

getStandardization()
Gets the value of standardization or its default value.

getTol()
Gets the value of tol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setAggregationDepth(value)
Sets the value of aggregationDepth.

setElasticNetParam(value)
Sets the value of elasticNetParam.

setEpsilon(value)
Sets the value of epsilon.

New in version 2.3.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.


setFitIntercept(value)
Sets the value of fitIntercept.

setLabelCol(value)
Sets the value of labelCol.

New in version 3.0.0.

setLoss(value)
Sets the value of loss.

setMaxBlockSizeInMB(value)
Sets the value of maxBlockSizeInMB.

New in version 3.1.0.

setMaxIter(value)
Sets the value of maxIter.

setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, standardization=True, solver="auto", weightCol=None, aggregationDepth=2, loss="squaredError", epsilon=1.35, maxBlockSizeInMB=0.0)
Sets params for linear regression.

New in version 1.4.0.
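Because setParams only accepts keyword arguments (see the TypeError in the class example above), parameters are passed by name; a brief sketch with illustrative values:

>>> lr.setParams(maxIter=10, regParam=0.01, elasticNetParam=0.5)
LinearRegression...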

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setRegParam(value)
Sets the value of regParam.

setSolver(value)
Sets the value of solver.

setStandardization(value)
Sets the value of standardization.

setTol(value)
Sets the value of tol.

setWeightCol(value)
Sets the value of weightCol.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')

elasticNetParam = Param(parent='undefined', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.')

epsilon = Param(parent='undefined', name='epsilon', doc='The shape parameter to control the amount of robustness. Must be > 1.0. Only valid when loss is huber')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')


loss = Param(parent='undefined', name='loss', doc='The loss function to be optimized. Supported options: squaredError, huber.')

maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')

solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: auto, normal, l-bfgs.')

standardization = Param(parent='undefined', name='standardization', doc='whether to standardize the training features before fitting the model.')

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

LinearRegressionModel

class pyspark.ml.regression.LinearRegressionModel(java_model=None)
Model fitted by LinearRegression.

New in version 1.4.0.

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset) Evaluates the model on a test dataset.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getAggregationDepth() Gets the value of aggregationDepth or its default value.
getElasticNetParam() Gets the value of elasticNetParam or its default value.
getEpsilon() Gets the value of epsilon or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFitIntercept() Gets the value of fitIntercept or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getLoss() Gets the value of loss or its default value.
getMaxBlockSizeInMB() Gets the value of maxBlockSizeInMB or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getRegParam() Gets the value of regParam or its default value.
getSolver() Gets the value of solver or its default value.
getStandardization() Gets the value of standardization or its default value.
getTol() Gets the value of tol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) Predict label for the given features.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setPredictionCol(value) Sets the value of predictionCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns a GeneralMLWriter instance for this ML instance.

Attributes

aggregationDepth
coefficients Model coefficients.
elasticNetParam
epsilon
featuresCol
fitIntercept
hasSummary Indicates whether a training summary exists for this model instance.
intercept Model intercept.
labelCol
loss
maxBlockSizeInMB
maxIter
numFeatures Returns the number of features the model was trained on.
params Returns all params ordered by name.
predictionCol
regParam
scale The value by which $\|y - X'w\|$ is scaled down when loss is "huber", otherwise 1.0.
solver
standardization
summary Gets summary (e.g. residuals, mse, r-squared) of model on training set.
tol
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluate(dataset)
Evaluates the model on a test dataset.

New in version 2.0.0.

Parameters

dataset [pyspark.sql.DataFrame] Test dataset to evaluate model on, where dataset is an instance of pyspark.sql.DataFrame
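A short sketch of evaluating on a labeled DataFrame (here simply reusing df and model from the class example; any DataFrame with the label and features columns would do):

>>> eval_summary = model.evaluate(df)
>>> eval_summary.numInstances
2
>>> eval_summary.meanAbsoluteError < 0.01
True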

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map


getAggregationDepth()
Gets the value of aggregationDepth or its default value.

getElasticNetParam()
Gets the value of elasticNetParam or its default value.

getEpsilon()
Gets the value of epsilon or its default value.

New in version 2.3.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getFitIntercept()
Gets the value of fitIntercept or its default value.

getLabelCol()
Gets the value of labelCol or its default value.

getLoss()
Gets the value of loss or its default value.

getMaxBlockSizeInMB()
Gets the value of maxBlockSizeInMB or its default value.

getMaxIter()
Gets the value of maxIter or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getRegParam()
Gets the value of regParam or its default value.

getSolver()
Gets the value of solver or its default value.

getStandardization()
Gets the value of standardization or its default value.

getTol()
Gets the value of tol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.


isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)
Predict label for the given features.

New in version 3.0.0.

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset
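As an illustrative sketch (not from the original examples), a param map passed to transform overrides the embedded params for that call only, e.g. redirecting the prediction column to a hypothetical name:

>>> out = model.transform(test0, {model.predictionCol: "tmpPrediction"})
>>> "tmpPrediction" in out.columns
True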

write()
Returns a GeneralMLWriter instance for this ML instance.

Attributes Documentation

aggregationDepth = Param(parent='undefined', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).')

coefficients
Model coefficients.

New in version 2.0.0.

elasticNetParam = Param(parent='undefined', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.')

epsilon = Param(parent='undefined', name='epsilon', doc='The shape parameter to control the amount of robustness. Must be > 1.0. Only valid when loss is huber')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')


hasSummary
Indicates whether a training summary exists for this model instance.

New in version 2.1.0.

intercept
Model intercept.

New in version 1.4.0.

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

loss = Param(parent='undefined', name='loss', doc='The loss function to be optimized. Supported options: squaredError, huber.')

maxBlockSizeInMB = Param(parent='undefined', name='maxBlockSizeInMB', doc='maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.

New in version 2.1.0.

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')

scale
The value by which $\|y - X'w\|$ is scaled down when loss is "huber", otherwise 1.0.

New in version 2.3.0.

solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: auto, normal, l-bfgs.')

standardization = Param(parent='undefined', name='standardization', doc='whether to standardize the training features before fitting the model.')

summary
Gets summary (e.g. residuals, mse, r-squared) of model on training set. An exception is thrown if trainingSummary is None.

New in version 2.0.0.
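A brief sketch of guarding access with hasSummary (model as fitted in the class example above):

>>> model.hasSummary
True
>>> model.summary.r2 > 0.99
True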

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

LinearRegressionSummary

class pyspark.ml.regression.LinearRegressionSummary(java_obj=None)
Linear regression results evaluated on a dataset.

New in version 2.0.0.


Attributes

coefficientStandardErrors Standard error of estimated coefficients and intercept.
degreesOfFreedom Degrees of freedom.
devianceResiduals The weighted residuals, the usual residuals rescaled by the square root of the instance weights.
explainedVariance Returns the explained variance regression score.
featuresCol Field in "predictions" which gives the features of each instance as a vector.
labelCol Field in "predictions" which gives the true label of each instance.
meanAbsoluteError Returns the mean absolute error, which is a risk function corresponding to the expected value of the absolute error loss or l1-norm loss.
meanSquaredError Returns the mean squared error, which is a risk function corresponding to the expected value of the squared error loss or quadratic loss.
numInstances Number of instances in DataFrame predictions.
pValues Two-sided p-value of estimated coefficients and intercept.
predictionCol Field in "predictions" which gives the predicted value of the label at each instance.
predictions Dataframe outputted by the model's transform method.
r2 Returns R^2, the coefficient of determination.
r2adj Returns Adjusted R^2, the adjusted coefficient of determination.
residuals Residuals (label - predicted value).
rootMeanSquaredError Returns the root mean squared error, which is defined as the square root of the mean squared error.
tValues T-statistic of estimated coefficients and intercept.

Attributes Documentation

coefficientStandardErrors
Standard error of estimated coefficients and intercept. This value is only available when using the "normal" solver.

If LinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.

New in version 2.0.0.

See also:

LinearRegression.solver
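A hedged sketch (model trained with solver="normal" as in the LinearRegression example, so these statistics are defined); with fitIntercept=True the returned list has one extra entry for the intercept:

>>> stats = model.summary
>>> len(stats.coefficientStandardErrors) == model.numFeatures + 1
True
>>> len(stats.pValues) == len(stats.tValues)
True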

degreesOfFreedom
Degrees of freedom.

New in version 2.2.0.


devianceResiduals
The weighted residuals, the usual residuals rescaled by the square root of the instance weights.

New in version 2.0.0.

explainedVariance
Returns the explained variance regression score. $explainedVariance = 1 - \frac{variance(y - \hat{y})}{variance(y)}$

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.

For additional information see Explained variation on Wikipedia

New in version 2.0.0.

featuresCol
Field in "predictions" which gives the features of each instance as a vector.

New in version 2.0.0.

labelCol
Field in "predictions" which gives the true label of each instance.

New in version 2.0.0.

meanAbsoluteError
Returns the mean absolute error, which is a risk function corresponding to the expected value of the absolute error loss or l1-norm loss.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.

New in version 2.0.0.

meanSquaredError
Returns the mean squared error, which is a risk function corresponding to the expected value of the squared error loss or quadratic loss.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.

New in version 2.0.0.

numInstances
Number of instances in DataFrame predictions

New in version 2.0.0.

pValues
Two-sided p-value of estimated coefficients and intercept. This value is only available when using the "normal" solver.

If LinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.


New in version 2.0.0.

See also:

LinearRegression.solver

predictionCol
Field in "predictions" which gives the predicted value of the label at each instance.

New in version 2.0.0.

predictions
Dataframe outputted by the model's transform method.

New in version 2.0.0.

r2
Returns R^2, the coefficient of determination.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.

See also Wikipedia coefficient of determination

New in version 2.0.0.

r2adj
Returns Adjusted R^2, the adjusted coefficient of determination.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.

Wikipedia coefficient of determination, Adjusted R^2

New in version 2.4.0.

residuals
Residuals (label - predicted value)

New in version 2.0.0.

rootMeanSquaredError
Returns the root mean squared error, which is defined as the square root of the mean squared error.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.

New in version 2.0.0.

tValues
T-statistic of estimated coefficients and intercept. This value is only available when using the "normal" solver.


If LinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.

New in version 2.0.0.

See also:

LinearRegression.solver

LinearRegressionTrainingSummary

class pyspark.ml.regression.LinearRegressionTrainingSummary(java_obj=None)
Linear regression training results. Currently, the training summary ignores the training weights except for the objective trace.

New in version 2.0.0.

Attributes

coefficientStandardErrors Standard error of estimated coefficients and intercept.
degreesOfFreedom Degrees of freedom.
devianceResiduals The weighted residuals, the usual residuals rescaled by the square root of the instance weights.
explainedVariance Returns the explained variance regression score.
featuresCol Field in "predictions" which gives the features of each instance as a vector.
labelCol Field in "predictions" which gives the true label of each instance.
meanAbsoluteError Returns the mean absolute error, which is a risk function corresponding to the expected value of the absolute error loss or l1-norm loss.
meanSquaredError Returns the mean squared error, which is a risk function corresponding to the expected value of the squared error loss or quadratic loss.
numInstances Number of instances in DataFrame predictions.
objectiveHistory Objective function (scaled loss + regularization) at each iteration.
pValues Two-sided p-value of estimated coefficients and intercept.
predictionCol Field in "predictions" which gives the predicted value of the label at each instance.
predictions Dataframe outputted by the model's transform method.
r2 Returns R^2, the coefficient of determination.
r2adj Returns Adjusted R^2, the adjusted coefficient of determination.
residuals Residuals (label - predicted value).
rootMeanSquaredError Returns the root mean squared error, which is defined as the square root of the mean squared error.
tValues T-statistic of estimated coefficients and intercept.
totalIterations Number of training iterations until termination.


Attributes Documentation

coefficientStandardErrors
Standard error of estimated coefficients and intercept. This value is only available when using the "normal" solver.

If LinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.

New in version 2.0.0.

See also:

LinearRegression.solver

degreesOfFreedom
Degrees of freedom.

New in version 2.2.0.

devianceResiduals
The weighted residuals, the usual residuals rescaled by the square root of the instance weights.

New in version 2.0.0.

explainedVariance
Returns the explained variance regression score. $explainedVariance = 1 - \frac{variance(y - \hat{y})}{variance(y)}$

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.

For additional information see Explained variation on Wikipedia

New in version 2.0.0.

featuresCol
Field in "predictions" which gives the features of each instance as a vector.

New in version 2.0.0.

labelCol
Field in "predictions" which gives the true label of each instance.

New in version 2.0.0.

meanAbsoluteError
Returns the mean absolute error, which is a risk function corresponding to the expected value of the absolute error loss or l1-norm loss.


Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.

New in version 2.0.0.

meanSquaredError
Returns the mean squared error, which is a risk function corresponding to the expected value of the squared error loss or quadratic loss.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.

New in version 2.0.0.

numInstances
Number of instances in DataFrame predictions

New in version 2.0.0.

objectiveHistory
Objective function (scaled loss + regularization) at each iteration. This value is only available when using the "l-bfgs" solver.

New in version 2.0.0.

See also:

LinearRegression.solver
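A small sketch, assuming an estimator configured with the "l-bfgs" solver (hypothetical settings, reusing df from the LinearRegression example) so that the objective trace is recorded:

>>> lbfgs_lr = LinearRegression(solver="l-bfgs", maxIter=10, weightCol="weight")
>>> lbfgs_model = lbfgs_lr.fit(df)
>>> len(lbfgs_model.summary.objectiveHistory) >= 1
True
>>> lbfgs_model.summary.totalIterations <= 10
True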

pValues
Two-sided p-value of estimated coefficients and intercept. This value is only available when using the "normal" solver.

If LinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.

New in version 2.0.0.

See also:

LinearRegression.solver

predictionCol
Field in "predictions" which gives the predicted value of the label at each instance.

New in version 2.0.0.

predictions
Dataframe outputted by the model's transform method.

New in version 2.0.0.

r2
Returns R^2, the coefficient of determination.


Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.

See also Wikipedia coefficient of determination

New in version 2.0.0.

r2adj
Returns Adjusted R^2, the adjusted coefficient of determination.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.

Wikipedia coefficient of determination, Adjusted R^2

New in version 2.4.0.

residuals
Residuals (label - predicted value)

New in version 2.0.0.

rootMeanSquaredError
Returns the root mean squared error, which is defined as the square root of the mean squared error.

Notes

This ignores instance weights (setting all to 1.0) from LinearRegression.weightCol. This will change in later Spark versions.

New in version 2.0.0.

tValues
T-statistic of estimated coefficients and intercept. This value is only available when using the "normal" solver.

If LinearRegression.fitIntercept is set to True, then the last element returned corresponds to the intercept.

New in version 2.0.0.

See also:

LinearRegression.solver

totalIterations
Number of training iterations until termination. This value is only available when using the "l-bfgs" solver.

New in version 2.0.0.

See also:

LinearRegression.solver


RandomForestRegressor

class pyspark.ml.regression.RandomForestRegressor(*args, **kwargs)
Random Forest learning algorithm for regression. It supports both continuous and categorical features.

New in version 1.4.0.

Examples

>>> from numpy import allclose
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> rf = RandomForestRegressor(numTrees=2, maxDepth=2)
>>> rf.getMinWeightFractionPerNode()
0.0
>>> rf.setSeed(42)
RandomForestRegressor...
>>> model = rf.fit(df)
>>> model.getBootstrap()
True
>>> model.getSeed()
42
>>> model.setLeafCol("leafId")
RandomForestRegressionModel...
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 1.0])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictLeaf(test0.head().features)
DenseVector([0.0, 0.0])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.leafId
DenseVector([0.0, 0.0])
>>> model.numFeatures
1
>>> model.trees
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
>>> model.getNumTrees
2
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
0.5
>>> rfr_path = temp_path + "/rfr"
>>> rf.save(rfr_path)
>>> rf2 = RandomForestRegressor.load(rfr_path)
>>> rf2.getNumTrees()
2
>>> model_path = temp_path + "/rfr_model"
>>> model.save(model_path)
>>> model2 = RandomForestRegressionModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getBootstrap() Gets the value of bootstrap or its default value.
getCacheNodeIds() Gets the value of cacheNodeIds or its default value.
getCheckpointInterval() Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy() Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getImpurity() Gets the value of impurity or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getLeafCol() Gets the value of leafCol or its default value.
getMaxBins() Gets the value of maxBins or its default value.
getMaxDepth() Gets the value of maxDepth or its default value.
getMaxMemoryInMB() Gets the value of maxMemoryInMB or its default value.
getMinInfoGain() Gets the value of minInfoGain or its default value.
getMinInstancesPerNode() Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode() Gets the value of minWeightFractionPerNode or its default value.
getNumTrees() Gets the value of numTrees or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getSeed() Gets the value of seed or its default value.
getSubsamplingRate() Gets the value of subsamplingRate or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setBootstrap(value) Sets the value of bootstrap.
setCacheNodeIds(value) Sets the value of cacheNodeIds.
setCheckpointInterval(value) Sets the value of checkpointInterval.
setFeatureSubsetStrategy(value) Sets the value of featureSubsetStrategy.
setFeaturesCol(value) Sets the value of featuresCol.
setImpurity(value) Sets the value of impurity.
setLabelCol(value) Sets the value of labelCol.
setLeafCol(value) Sets the value of leafCol.
setMaxBins(value) Sets the value of maxBins.
setMaxDepth(value) Sets the value of maxDepth.
setMaxMemoryInMB(value) Sets the value of maxMemoryInMB.
setMinInfoGain(value) Sets the value of minInfoGain.
setMinInstancesPerNode(value) Sets the value of minInstancesPerNode.
setMinWeightFractionPerNode(value) Sets the value of minWeightFractionPerNode.
setNumTrees(value) Sets the value of numTrees.
setParams(self, *[, featuresCol, labelCol, ...]) Sets params for random forest regression.
setPredictionCol(value) Sets the value of predictionCol.
setSeed(value) Sets the value of seed.
setSubsamplingRate(value) Sets the value of subsamplingRate.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.

Attributes

bootstrap
cacheNodeIds
checkpointInterval
featureSubsetStrategy
featuresCol
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numTrees
params Returns all params ordered by name.
predictionCol
seed
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.


New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getBootstrap()
Gets the value of bootstrap or its default value.

New in version 3.0.0.

getCacheNodeIds()
Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.

getFeatureSubsetStrategy()
Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getImpurity()
Gets the value of impurity or its default value.

New in version 1.4.0.

getLabelCol()
Gets the value of labelCol or its default value.

getLeafCol()
Gets the value of leafCol or its default value.

getMaxBins()
Gets the value of maxBins or its default value.

getMaxDepth()
Gets the value of maxDepth or its default value.


getMaxMemoryInMB()
Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()
Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()
Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()
Gets the value of minWeightFractionPerNode or its default value.

getNumTrees()
Gets the value of numTrees or its default value.

New in version 1.4.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getSeed()
Gets the value of seed or its default value.

getSubsamplingRate()
Gets the value of subsamplingRate or its default value.

New in version 1.4.0.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setBootstrap(value)
Sets the value of bootstrap.


New in version 3.0.0.

setCacheNodeIds(value)
Sets the value of cacheNodeIds.

setCheckpointInterval(value)
Sets the value of checkpointInterval.

setFeatureSubsetStrategy(value)
Sets the value of featureSubsetStrategy.

New in version 2.4.0.

setFeaturesCol(value)
Sets the value of featuresCol.

New in version 3.0.0.

setImpurity(value)
Sets the value of impurity.

New in version 1.4.0.

setLabelCol(value)
Sets the value of labelCol.

New in version 3.0.0.

setLeafCol(value)
Sets the value of leafCol.

setMaxBins(value)
Sets the value of maxBins.

setMaxDepth(value)
Sets the value of maxDepth.

setMaxMemoryInMB(value)
Sets the value of maxMemoryInMB.

setMinInfoGain(value)
Sets the value of minInfoGain.

setMinInstancesPerNode(value)
Sets the value of minInstancesPerNode.

setMinWeightFractionPerNode(value)
Sets the value of minWeightFractionPerNode.

New in version 3.0.0.

setNumTrees(value)
Sets the value of numTrees.

New in version 1.4.0.

setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)

Sets params for random forest regression.

New in version 1.4.0.


setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

setSeed(value)
Sets the value of seed.

setSubsamplingRate(value)
Sets the value of subsamplingRate.

New in version 1.4.0.

setWeightCol(value)
Sets the value of weightCol.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

bootstrap = Param(parent='undefined', name='bootstrap', doc='Whether bootstrap samples are used when building trees.')

cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')

maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')

maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')

maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')

minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')

minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')

minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')

numTrees = Param(parent='undefined', name='numTrees', doc='Number of trees to train (>= 1).')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')

supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']

supportedImpurities = ['variance']


weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

RandomForestRegressionModel

class pyspark.ml.regression.RandomForestRegressionModel(java_model=None)
Model fitted by RandomForestRegressor.

New in version 1.4.0.

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBootstrap() Gets the value of bootstrap or its default value.
getCacheNodeIds() Gets the value of cacheNodeIds or its default value.
getCheckpointInterval() Gets the value of checkpointInterval or its default value.
getFeatureSubsetStrategy() Gets the value of featureSubsetStrategy or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getImpurity() Gets the value of impurity or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getLeafCol() Gets the value of leafCol or its default value.
getMaxBins() Gets the value of maxBins or its default value.
getMaxDepth() Gets the value of maxDepth or its default value.
getMaxMemoryInMB() Gets the value of maxMemoryInMB or its default value.
getMinInfoGain() Gets the value of minInfoGain or its default value.
getMinInstancesPerNode() Gets the value of minInstancesPerNode or its default value.
getMinWeightFractionPerNode() Gets the value of minWeightFractionPerNode or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getSeed() Gets the value of seed or its default value.
getSubsamplingRate() Gets the value of subsamplingRate or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) Predict label for the given features.
predictLeaf(value) Predict the indices of the leaves corresponding to the feature vector.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setLeafCol(value) Sets the value of leafCol.
setPredictionCol(value) Sets the value of predictionCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

bootstrap
cacheNodeIds
checkpointInterval
featureImportances Estimate of the importance of each feature.
featureSubsetStrategy
featuresCol
getNumTrees Number of trees in ensemble.
impurity
labelCol
leafCol
maxBins
maxDepth
maxMemoryInMB
minInfoGain
minInstancesPerNode
minWeightFractionPerNode
numFeatures Returns the number of features the model was trained on.
numTrees
params Returns all params ordered by name.
predictionCol
seed
subsamplingRate
supportedFeatureSubsetStrategies
supportedImpurities
toDebugString Full description of model.
totalNumNodes Total number of nodes, summed over all trees in the ensemble.
treeWeights Return the weights for each tree.
trees Trees in this ensemble.
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getBootstrap()
Gets the value of bootstrap or its default value.

New in version 3.0.0.

getCacheNodeIds()
Gets the value of cacheNodeIds or its default value.

getCheckpointInterval()
Gets the value of checkpointInterval or its default value.


getFeatureSubsetStrategy()Gets the value of featureSubsetStrategy or its default value.

New in version 1.4.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getImpurity()Gets the value of impurity or its default value.

New in version 1.4.0.

getLabelCol()Gets the value of labelCol or its default value.

getLeafCol()Gets the value of leafCol or its default value.

getMaxBins()Gets the value of maxBins or its default value.

getMaxDepth()Gets the value of maxDepth or its default value.

getMaxMemoryInMB()Gets the value of maxMemoryInMB or its default value.

getMinInfoGain()Gets the value of minInfoGain or its default value.

getMinInstancesPerNode()Gets the value of minInstancesPerNode or its default value.

getMinWeightFractionPerNode()Gets the value of minWeightFractionPerNode or its default value.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getSeed()Gets the value of seed or its default value.

getSubsamplingRate()Gets the value of subsamplingRate or its default value.

New in version 1.4.0.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.


isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)Predict label for the given features.

New in version 3.0.0.

predictLeaf(value)Predict the indices of the leaves corresponding to the feature vector.

New in version 3.0.0.

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setLeafCol(value)Sets the value of leafCol.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

transform(dataset, params=None)
Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset
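For instance, in a minimal sketch assuming a fitted regression model named model and a DataFrame test with a features column, the params argument can override an embedded param for a single call:

# Override predictionCol only for this transform() call.
predictions = model.transform(test, {model.predictionCol: "tmp_prediction"})
predictions.select("tmp_prediction").show()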

write()Returns an MLWriter instance for this ML instance.


Attributes Documentation

bootstrap = Param(parent='undefined', name='bootstrap', doc='Whether bootstrap samples are used when building trees.')

cacheNodeIds = Param(parent='undefined', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.')

checkpointInterval = Param(parent='undefined', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.')

featureImportances
Estimate of the importance of each feature.

Each feature's importance is the average of its importance across all trees in the ensemble. The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) and follows the implementation from scikit-learn.

New in version 2.0.0.

Examples

DecisionTreeRegressionModel.featureImportances
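A minimal sketch of reading this attribute from a fitted model, assuming a DataFrame train with "features" and "label" columns:

from pyspark.ml.regression import RandomForestRegressor

rf = RandomForestRegressor(numTrees=10, featuresCol="features", labelCol="label")
model = rf.fit(train)                      # train is an assumed DataFrame
importances = model.featureImportances     # a Vector normalized to sum to 1
print(importances.toArray())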

featureSubsetStrategy = Param(parent='undefined', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supported options: 'auto' (choose automatically for task: If numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to 'onethird' for regression), 'all' (use all features), 'onethird' (use 1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use log2(number of features)), 'n' (when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features). default = 'auto'")

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

getNumTrees
Number of trees in ensemble.

New in version 2.0.0.

impurity = Param(parent='undefined', name='impurity', doc='Criterion used for information gain calculation (case-insensitive). Supported options: variance')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

leafCol = Param(parent='undefined', name='leafCol', doc='Leaf indices column name. Predicted leaf index of each instance in each tree by preorder.')

maxBins = Param(parent='undefined', name='maxBins', doc='Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.')

maxDepth = Param(parent='undefined', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.')

maxMemoryInMB = Param(parent='undefined', name='maxMemoryInMB', doc='Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.')

minInfoGain = Param(parent='undefined', name='minInfoGain', doc='Minimum information gain for a split to be considered at a tree node.')

minInstancesPerNode = Param(parent='undefined', name='minInstancesPerNode', doc='Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.')

minWeightFractionPerNode = Param(parent='undefined', name='minWeightFractionPerNode', doc='Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in interval [0.0, 0.5).')

numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.

New in version 2.1.0.

numTrees = Param(parent='undefined', name='numTrees', doc='Number of trees to train (>= 1).')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

seed = Param(parent='undefined', name='seed', doc='random seed.')

subsamplingRate = Param(parent='undefined', name='subsamplingRate', doc='Fraction of the training data used for learning each decision tree, in range (0, 1].')


supportedFeatureSubsetStrategies = ['auto', 'all', 'onethird', 'sqrt', 'log2']

supportedImpurities = ['variance']

toDebugString
Full description of model.

New in version 2.0.0.

totalNumNodes
Total number of nodes, summed over all trees in the ensemble.

New in version 2.0.0.

treeWeights
Return the weights for each tree.

New in version 1.5.0.

trees
Trees in this ensemble. Warning: These have null parent Estimators.

New in version 2.0.0.

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

FMRegressor

class pyspark.ml.regression.FMRegressor(*args, **kwargs)
Factorization Machines learning algorithm for regression.

solver Supports:

• gd (normal mini-batch gradient descent)

• adamW (default)

New in version 3.0.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.regression import FMRegressor
>>> df = spark.createDataFrame([
...     (2.0, Vectors.dense(2.0)),
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>>
>>> fm = FMRegressor(factorSize=2)
>>> fm.setSeed(16)
FMRegressor...
>>> model = fm.fit(df)
>>> model.getMaxIter()
100
>>> test0 = spark.createDataFrame([
...     (Vectors.dense(-2.0),),
...     (Vectors.dense(0.5),),
...     (Vectors.dense(1.0),),
...     (Vectors.dense(4.0),)], ["features"])
>>> model.transform(test0).show(10, False)
+--------+-------------------+
|features|prediction         |
+--------+-------------------+
|[-2.0]  |-1.9989237712341565|
|[0.5]   |0.4956682219523814 |
|[1.0]   |0.994586620589689  |
|[4.0]   |3.9880970124135344 |
+--------+-------------------+
...
>>> model.intercept
-0.0032501766849261557
>>> model.linear
DenseVector([0.9978])
>>> model.factors
DenseMatrix(1, 2, [0.0173, 0.0021], 1)
>>> model_path = temp_path + "/fm_model"
>>> model.save(model_path)
>>> model2 = FMRegressionModel.load(model_path)
>>> model2.intercept
-0.0032501766849261557
>>> model2.linear
DenseVector([0.9978])
>>> model2.factors
DenseMatrix(1, 2, [0.0173, 0.0021], 1)
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getFactorSize() Gets the value of factorSize or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFitIntercept() Gets the value of fitIntercept or its default value.
getFitLinear() Gets the value of fitLinear or its default value.
getInitStd() Gets the value of initStd or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getMiniBatchFraction() Gets the value of miniBatchFraction or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getRegParam() Gets the value of regParam or its default value.
getSeed() Gets the value of seed or its default value.
getSolver() Gets the value of solver or its default value.
getStepSize() Gets the value of stepSize or its default value.
getTol() Gets the value of tol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFactorSize(value) Sets the value of factorSize.
setFeaturesCol(value) Sets the value of featuresCol.
setFitIntercept(value) Sets the value of fitIntercept.
setFitLinear(value) Sets the value of fitLinear.
setInitStd(value) Sets the value of initStd.
setLabelCol(value) Sets the value of labelCol.
setMaxIter(value) Sets the value of maxIter.
setMiniBatchFraction(value) Sets the value of miniBatchFraction.
setParams(self, *[, featuresCol, labelCol, ...]) Sets Params for FMRegressor.
setPredictionCol(value) Sets the value of predictionCol.
setRegParam(value) Sets the value of regParam.
setSeed(value) Sets the value of seed.
setSolver(value) Sets the value of solver.
setStepSize(value) Sets the value of stepSize.
setTol(value) Sets the value of tol.
write() Returns an MLWriter instance for this ML instance.


Attributes

factorSize
featuresCol
fitIntercept
fitLinear
initStd
labelCol
maxIter
miniBatchFraction
params Returns all params ordered by name.
predictionCol
regParam
seed
solver
stepSize
tol
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map


fit(dataset, params=None)
Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)
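A minimal sketch of the list form, assuming a DataFrame train with "features" and "label" columns; passing several param maps returns one fitted model per map:

from pyspark.ml.regression import FMRegressor

fm = FMRegressor(featuresCol="features", labelCol="label")
param_maps = [{fm.maxIter: 10}, {fm.maxIter: 50}]
models = fm.fit(train, params=param_maps)   # a list with one model per param map
for m in models:
    print(m.getMaxIter())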

fitMultiple(dataset, paramMaps)
Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
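A minimal sketch of consuming the iterator, again assuming a DataFrame train; because indices may arrive out of order, each model is stored in its own slot:

from pyspark.ml.regression import FMRegressor

fm = FMRegressor(featuresCol="features", labelCol="label")
param_maps = [{fm.regParam: 0.0}, {fm.regParam: 0.1}]
models = [None] * len(param_maps)
for index, model in fm.fitMultiple(train, param_maps):
    models[index] = model                   # model was fit with param_maps[index]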

getFactorSize()Gets the value of factorSize or its default value.

New in version 3.0.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getFitIntercept()Gets the value of fitIntercept or its default value.

getFitLinear()Gets the value of fitLinear or its default value.

New in version 3.0.0.

getInitStd()Gets the value of initStd or its default value.

New in version 3.0.0.

getLabelCol()Gets the value of labelCol or its default value.

getMaxIter()Gets the value of maxIter or its default value.

getMiniBatchFraction()Gets the value of miniBatchFraction or its default value.

New in version 3.0.0.


getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getRegParam()Gets the value of regParam or its default value.

getSeed()Gets the value of seed or its default value.

getSolver()Gets the value of solver or its default value.

getStepSize()Gets the value of stepSize or its default value.

getTol()Gets the value of tol or its default value.

getWeightCol()Gets the value of weightCol or its default value.

hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFactorSize(value)Sets the value of factorSize.

New in version 3.0.0.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setFitIntercept(value)Sets the value of fitIntercept.


New in version 3.0.0.

setFitLinear(value)Sets the value of fitLinear.

New in version 3.0.0.

setInitStd(value)Sets the value of initStd.

New in version 3.0.0.

setLabelCol(value)Sets the value of labelCol.

New in version 3.0.0.

setMaxIter(value)Sets the value of maxIter.

New in version 3.0.0.

setMiniBatchFraction(value)Sets the value of miniBatchFraction.

New in version 3.0.0.

setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, tol=1e-6, solver="adamW", seed=None)

Sets Params for FMRegressor.

New in version 3.0.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

setRegParam(value)Sets the value of regParam.

New in version 3.0.0.

setSeed(value)Sets the value of seed.

New in version 3.0.0.

setSolver(value)Sets the value of solver.

New in version 3.0.0.

setStepSize(value)Sets the value of stepSize.

New in version 3.0.0.

setTol(value)Sets the value of tol.

New in version 3.0.0.

write()Returns an MLWriter instance for this ML instance.


Attributes Documentation

factorSize = Param(parent='undefined', name='factorSize', doc='Dimensionality of the factor vectors, which are used to get pairwise interactions between variables')

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')

fitLinear = Param(parent='undefined', name='fitLinear', doc='whether to fit linear term (aka 1-way term)')

initStd = Param(parent='undefined', name='initStd', doc='standard deviation of initial coefficients')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

miniBatchFraction = Param(parent='undefined', name='miniBatchFraction', doc='fraction of the input data set that should be used for one iteration of gradient descent')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')

seed = Param(parent='undefined', name='seed', doc='random seed.')

solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: gd, adamW. (Default adamW)')

stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

FMRegressionModel

class pyspark.ml.regression.FMRegressionModel(java_model=None)
Model fitted by FMRegressor.

New in version 3.0.0.

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getFactorSize() Gets the value of factorSize or its default value.
getFeaturesCol() Gets the value of featuresCol or its default value.
getFitIntercept() Gets the value of fitIntercept or its default value.
getFitLinear() Gets the value of fitLinear or its default value.
getInitStd() Gets the value of initStd or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getMaxIter() Gets the value of maxIter or its default value.
getMiniBatchFraction() Gets the value of miniBatchFraction or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getRegParam() Gets the value of regParam or its default value.
getSeed() Gets the value of seed or its default value.
getSolver() Gets the value of solver or its default value.
getStepSize() Gets the value of stepSize or its default value.
getTol() Gets the value of tol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
predict(value) Predict label for the given features.
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setFeaturesCol(value) Sets the value of featuresCol.
setPredictionCol(value) Sets the value of predictionCol.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.


Attributes

factorSize
factors Model factor term.
featuresCol
fitIntercept
fitLinear
initStd
intercept Model intercept.
labelCol
linear Model linear term.
maxIter
miniBatchFraction
numFeatures Returns the number of features the model was trained on.
params Returns all params ordered by name.
predictionCol
regParam
seed
solver
stepSize
tol
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters


extra [dict, optional] extra param values

Returns

dict merged param map

getFactorSize()Gets the value of factorSize or its default value.

New in version 3.0.0.

getFeaturesCol()Gets the value of featuresCol or its default value.

getFitIntercept()Gets the value of fitIntercept or its default value.

getFitLinear()Gets the value of fitLinear or its default value.

New in version 3.0.0.

getInitStd()Gets the value of initStd or its default value.

New in version 3.0.0.

getLabelCol()Gets the value of labelCol or its default value.

getMaxIter()Gets the value of maxIter or its default value.

getMiniBatchFraction()Gets the value of miniBatchFraction or its default value.

New in version 3.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)Gets a param by its name.

getPredictionCol()Gets the value of predictionCol or its default value.

getRegParam()Gets the value of regParam or its default value.

getSeed()Gets the value of seed or its default value.

getSolver()Gets the value of solver or its default value.

getStepSize()Gets the value of stepSize or its default value.

getTol()Gets the value of tol or its default value.

getWeightCol()Gets the value of weightCol or its default value.


hasDefault(param)Checks whether a param has a default value.

hasParam(paramName)Tests whether this instance contains a param with a given (string) name.

isDefined(param)Checks whether a param is explicitly set by user or has a default value.

isSet(param)Checks whether a param is explicitly set by user.

classmethod load(path)Reads an ML instance from the input path, a shortcut of read().load(path).

predict(value)Predict label for the given features.

New in version 3.0.0.

classmethod read()Returns an MLReader instance for this class.

save(path)Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)Sets a parameter in the embedded param map.

setFeaturesCol(value)Sets the value of featuresCol.

New in version 3.0.0.

setPredictionCol(value)Sets the value of predictionCol.

New in version 3.0.0.

transform(dataset, params=None)Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write()Returns an MLWriter instance for this ML instance.


Attributes Documentation

factorSize = Param(parent='undefined', name='factorSize', doc='Dimensionality of the factor vectors, which are used to get pairwise interactions between variables')

factorsModel factor term.

New in version 3.0.0.

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

fitIntercept = Param(parent='undefined', name='fitIntercept', doc='whether to fit an intercept term.')

fitLinear = Param(parent='undefined', name='fitLinear', doc='whether to fit linear term (aka 1-way term)')

initStd = Param(parent='undefined', name='initStd', doc='standard deviation of initial coefficients')

interceptModel intercept.

New in version 3.0.0.

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

linearModel linear term.

New in version 3.0.0.

maxIter = Param(parent='undefined', name='maxIter', doc='max number of iterations (>= 0).')

miniBatchFraction = Param(parent='undefined', name='miniBatchFraction', doc='fraction of the input data set that should be used for one iteration of gradient descent')

numFeatures
Returns the number of features the model was trained on. If unknown, returns -1.

New in version 2.1.0.

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

regParam = Param(parent='undefined', name='regParam', doc='regularization parameter (>= 0).')

seed = Param(parent='undefined', name='seed', doc='random seed.')

solver = Param(parent='undefined', name='solver', doc='The solver algorithm for optimization. Supported options: gd, adamW. (Default adamW)')

stepSize = Param(parent='undefined', name='stepSize', doc='Step size to be used for each iteration of optimization (>= 0).')

tol = Param(parent='undefined', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


3.3.10 Statistics

ANOVATest() Conduct ANOVA Classification Test for continuous features against categorical labels.
ChiSquareTest() Conduct Pearson's independence test for every feature against the label.
Correlation() Compute the correlation matrix for the input dataset of Vectors using the specified method.
FValueTest() Conduct F Regression test for continuous features against continuous labels.
KolmogorovSmirnovTest() Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a continuous distribution.
MultivariateGaussian(mean, cov) Represents a (mean, cov) tuple.
Summarizer() Tools for vectorized statistics on MLlib Vectors.
SummaryBuilder(jSummaryBuilder) A builder object that provides summary statistics about a given column.

ANOVATest

class pyspark.ml.stat.ANOVATest
Conduct ANOVA Classification Test for continuous features against categorical labels.

New in version 3.1.0.

Methods

test(dataset, featuresCol, labelCol[, flatten]) Perform an ANOVA test using dataset.

Methods Documentation

static test(dataset, featuresCol, labelCol, flatten=False)
Perform an ANOVA test using dataset.

New in version 3.1.0.

Parameters

dataset [pyspark.sql.DataFrame] DataFrame of categorical labels and continuous features.

featuresCol [str] Name of features column in dataset, of type Vector (VectorUDT).

labelCol [str] Name of label column in dataset, of any numerical type.

flatten [bool, optional] if True, flattens the returned dataframe.

Returns

pyspark.sql.DataFrame DataFrame containing the test result for every feature against the label. If flatten is True, this DataFrame will contain one row per feature with the following fields:

• featureIndex: int

• pValue: float


• degreesOfFreedom: int

• fValue: float

If flatten is False, this DataFrame will contain a single Row with the following fields:

• pValues: Vector

• degreesOfFreedom: Array[int]

• fValues: Vector

Each of these fields has one value per feature.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.stat import ANOVATest
>>> dataset = [[2.0, Vectors.dense([0.43486404, 0.57153633, 0.43175686,
...                                 0.51418671, 0.61632374, 0.96565515])],
...            [1.0, Vectors.dense([0.49162732, 0.6785187, 0.85460572,
...                                 0.59784822, 0.12394819, 0.53783355])],
...            [2.0, Vectors.dense([0.30879653, 0.54904515, 0.17103889,
...                                 0.40492506, 0.18957493, 0.5440016])],
...            [3.0, Vectors.dense([0.68114391, 0.60549825, 0.69094651,
...                                 0.62102109, 0.05471483, 0.96449167])]]
>>> dataset = spark.createDataFrame(dataset, ["label", "features"])
>>> anovaResult = ANOVATest.test(dataset, 'features', 'label')
>>> row = anovaResult.select("fValues", "pValues").collect()
>>> row[0].fValues
DenseVector([4.0264, 18.4713, 3.4659, 1.9042, 0.5532, 0.512])
>>> row[0].pValues
DenseVector([0.3324, 0.1623, 0.3551, 0.456, 0.689, 0.7029])
>>> anovaResult = ANOVATest.test(dataset, 'features', 'label', True)
>>> row = anovaResult.orderBy("featureIndex").collect()
>>> row[0].fValue
4.026438671875297

ChiSquareTest

class pyspark.ml.stat.ChiSquareTest
Conduct Pearson's independence test for every feature against the label. For each feature, the (feature, label) pairs are converted into a contingency matrix for which the Chi-squared statistic is computed. All label and feature values must be categorical.

The null hypothesis is that the occurrence of the outcomes is statistically independent.

New in version 2.2.0.


Methods

test(dataset, featuresCol, labelCol[, flatten]) Perform a Pearson’s independence test using dataset.

Methods Documentation

static test(dataset, featuresCol, labelCol, flatten=False)Perform a Pearson’s independence test using dataset.

New in version 2.2.0.

Changed in version 3.1.0: Added optional flatten argument.

Parameters

dataset [pyspark.sql.DataFrame] DataFrame of categorical labels and categoricalfeatures. Real-valued features will be treated as categorical for each distinct value.

featuresCol [str] Name of features column in dataset, of type Vector (VectorUDT).

labelCol [str] Name of label column in dataset, of any numerical type.

flatten [bool, optional] if True, flattens the returned dataframe.

Returns

pyspark.sql.DataFrame DataFrame containing the test result for every feature against the label. If flatten is True, this DataFrame will contain one row per feature with the following fields:

• featureIndex: int

• pValue: float

• degreesOfFreedom: int

• statistic: float

If flatten is False, this DataFrame will contain a single Row with the following fields:

• pValues: Vector

• degreesOfFreedom: Array[int]

• statistics: Vector

Each of these fields has one value per feature.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.stat import ChiSquareTest
>>> dataset = [[0, Vectors.dense([0, 0, 1])],
...            [0, Vectors.dense([1, 0, 1])],
...            [1, Vectors.dense([2, 1, 1])],
...            [1, Vectors.dense([3, 1, 1])]]
>>> dataset = spark.createDataFrame(dataset, ["label", "features"])
>>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label')
>>> chiSqResult.select("degreesOfFreedom").collect()[0]
Row(degreesOfFreedom=[3, 1, 0])
>>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label', True)
>>> row = chiSqResult.orderBy("featureIndex").collect()
>>> row[0].statistic
4.0

Correlation

class pyspark.ml.stat.Correlation
Compute the correlation matrix for the input dataset of Vectors using the specified method. Methods currently supported: pearson (default), spearman.

New in version 2.2.0.

Notes

For Spearman, a rank correlation, we need to create an RDD[Double] for each column and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], which is fairly costly. Cache the input Dataset before calling corr with method = 'spearman' to avoid recomputing the common lineage.
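A minimal sketch of that advice, assuming a DataFrame dataset with a "features" column of Vectors:

from pyspark.ml.stat import Correlation

dataset.cache()        # avoid recomputing the lineage for the rank step
spearman = Correlation.corr(dataset, "features", method="spearman").head()[0]
dataset.unpersist()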

Methods

corr(dataset, column[, method]) Compute the correlation matrix with specifiedmethod using dataset.

Methods Documentation

static corr(dataset, column, method='pearson')Compute the correlation matrix with specified method using dataset.

New in version 2.2.0.

Parameters

dataset [pyspark.sql.DataFrame] A DataFrame.

column [str] The name of the column of vectors for which the correlation coefficient needs to be computed. This must be a column of the dataset, and it must contain Vector objects.

method [str, optional] String specifying the method to use for computing correlation. Supported: pearson (default), spearman.

Returns

A DataFrame that contains the correlation matrix of the column of vectors. This DataFrame contains a single row and a single column of name METHODNAME(COLUMN).


Examples

>>> from pyspark.ml.linalg import DenseMatrix, Vectors
>>> from pyspark.ml.stat import Correlation
>>> dataset = [[Vectors.dense([1, 0, 0, -2])],
...            [Vectors.dense([4, 5, 0, 3])],
...            [Vectors.dense([6, 7, 0, 8])],
...            [Vectors.dense([9, 0, 0, 1])]]
>>> dataset = spark.createDataFrame(dataset, ['features'])
>>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0]
>>> print(str(pearsonCorr).replace('nan', 'NaN'))
DenseMatrix([[ 1.       ,  0.0556...,        NaN,  0.4004...],
             [ 0.0556...,  1.       ,        NaN,  0.9135...],
             [       NaN,        NaN,  1.       ,        NaN],
             [ 0.4004...,  0.9135...,        NaN,  1.       ]])
>>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0]
>>> print(str(spearmanCorr).replace('nan', 'NaN'))
DenseMatrix([[ 1.       ,  0.1054...,        NaN,  0.4      ],
             [ 0.1054...,  1.       ,        NaN,  0.9486...],
             [       NaN,        NaN,  1.       ,        NaN],
             [ 0.4      ,  0.9486...,        NaN,  1.       ]])

FValueTest

class pyspark.ml.stat.FValueTest
Conduct F Regression test for continuous features against continuous labels.

New in version 3.1.0.

Methods

test(dataset, featuresCol, labelCol[, flatten]) Perform a F Regression test using dataset.

Methods Documentation

static test(dataset, featuresCol, labelCol, flatten=False)Perform a F Regression test using dataset.

New in version 3.1.0.

Parameters

dataset [pyspark.sql.DataFrame] DataFrame of continuous labels and continuousfeatures.

featuresCol [str] Name of features column in dataset, of type Vector (VectorUDT).

labelCol [str] Name of label column in dataset, of any numerical type.

flatten [bool, optional] if True, flattens the returned dataframe.

Returns


pyspark.sql.DataFrame DataFrame containing the test result for every feature against the label. If flatten is True, this DataFrame will contain one row per feature with the following fields:

• featureIndex: int

• pValue: float

• degreesOfFreedom: int

• fValue: float

If flatten is False, this DataFrame will contain a single Row with the following fields:

• pValues: Vector

• degreesOfFreedom: Array[int]

• fValues: Vector

Each of these fields has one value per feature.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.stat import FValueTest
>>> dataset = [[0.57495218, Vectors.dense([0.43486404, 0.57153633, 0.43175686,
...                                        0.51418671, 0.61632374, 0.96565515])],
...            [0.84619853, Vectors.dense([0.49162732, 0.6785187, 0.85460572,
...                                        0.59784822, 0.12394819, 0.53783355])],
...            [0.39777647, Vectors.dense([0.30879653, 0.54904515, 0.17103889,
...                                        0.40492506, 0.18957493, 0.5440016])],
...            [0.79201573, Vectors.dense([0.68114391, 0.60549825, 0.69094651,
...                                        0.62102109, 0.05471483, 0.96449167])]]
>>> dataset = spark.createDataFrame(dataset, ["label", "features"])
>>> fValueResult = FValueTest.test(dataset, 'features', 'label')
>>> row = fValueResult.select("fValues", "pValues").collect()
>>> row[0].fValues
DenseVector([3.741, 7.5807, 142.0684, 34.9849, 0.4112, 0.0539])
>>> row[0].pValues
DenseVector([0.1928, 0.1105, 0.007, 0.0274, 0.5871, 0.838])
>>> fValueResult = FValueTest.test(dataset, 'features', 'label', True)
>>> row = fValueResult.orderBy("featureIndex").collect()
>>> row[0].fValue
3.7409548308350593


KolmogorovSmirnovTest

class pyspark.ml.stat.KolmogorovSmirnovTest
Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a continuous distribution.

By comparing the largest difference between the empirical cumulative distribution of the sample data and the theoretical distribution, we can provide a test for the null hypothesis that the sample data comes from that theoretical distribution.

New in version 2.4.0.

Methods

test(dataset, sampleCol, distName, *params) Conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability distribution equality.

Methods Documentation

static test(dataset, sampleCol, distName, *params)
Conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability distribution equality. Currently supports the normal distribution, taking as parameters the mean and standard deviation.

New in version 2.4.0.

Parameters

dataset [pyspark.sql.DataFrame] a Dataset or a DataFrame containing the sample of data to test.

sampleCol [str] Name of sample column in dataset, of any numerical type.

distName [str] a string name for a theoretical distribution; currently only "norm" is supported.

params [float] a list of float values specifying the parameters to be used for the theoretical distribution. For the "norm" distribution, the parameters include mean and variance.

Returns

A DataFrame that contains the Kolmogorov-Smirnov test result for the input sampled data.

This DataFrame will contain a single Row with the following fields:

• pValue: Double

• statistic: Double


Examples

>>> from pyspark.ml.stat import KolmogorovSmirnovTest
>>> dataset = [[-1.0], [0.0], [1.0]]
>>> dataset = spark.createDataFrame(dataset, ['sample'])
>>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 0.0, 1.0).first()
>>> round(ksResult.pValue, 3)
1.0
>>> round(ksResult.statistic, 3)
0.175
>>> dataset = [[2.0], [3.0], [4.0]]
>>> dataset = spark.createDataFrame(dataset, ['sample'])
>>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 3.0, 1.0).first()
>>> round(ksResult.pValue, 3)
1.0
>>> round(ksResult.statistic, 3)
0.175

MultivariateGaussian

class pyspark.ml.stat.MultivariateGaussian(mean, cov)
Represents a (mean, cov) tuple.

New in version 3.0.0.

Examples

>>> from pyspark.ml.linalg import DenseMatrix, Vectors
>>> m = MultivariateGaussian(Vectors.dense([11,12]), DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0)))
>>> (m.mean, m.cov.toArray())
(DenseVector([11.0, 12.0]), array([[ 1.,  5.],
       [ 3.,  2.]]))

Summarizer

class pyspark.ml.stat.Summarizer
Tools for vectorized statistics on MLlib Vectors. The methods in this package provide various statistics for Vectors contained inside DataFrames. This class lets users pick the statistics they would like to extract for a given column.

New in version 2.4.0.


Examples

>>> from pyspark.ml.stat import Summarizer
>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
>>> summarizer = Summarizer.metrics("mean", "count")
>>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
...                      Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
>>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False)
+-----------------------------------+
|aggregate_metrics(features, weight)|
+-----------------------------------+
|{[1.0,1.0,1.0], 1}                 |
+-----------------------------------+
>>> df.select(summarizer.summary(df.features)).show(truncate=False)
+--------------------------------+
|aggregate_metrics(features, 1.0)|
+--------------------------------+
|{[1.0,1.5,2.0], 2}              |
+--------------------------------+
>>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False)
+--------------+
|mean(features)|
+--------------+
|[1.0,1.0,1.0] |
+--------------+
>>> df.select(Summarizer.mean(df.features)).show(truncate=False)
+--------------+
|mean(features)|
+--------------+
|[1.0,1.5,2.0] |
+--------------+

Methods

count(col[, weightCol]) return a column of count summary
max(col[, weightCol]) return a column of max summary
mean(col[, weightCol]) return a column of mean summary
metrics(*metrics) Given a list of metrics, provides a builder that in turn computes metrics from a column.
min(col[, weightCol]) return a column of min summary
normL1(col[, weightCol]) return a column of normL1 summary
normL2(col[, weightCol]) return a column of normL2 summary
numNonZeros(col[, weightCol]) return a column of numNonZero summary
std(col[, weightCol]) return a column of std summary
sum(col[, weightCol]) return a column of sum summary
variance(col[, weightCol]) return a column of variance summary


Methods Documentation

static count(col, weightCol=None)return a column of count summary

New in version 2.4.0.

static max(col, weightCol=None)return a column of max summary

New in version 2.4.0.

static mean(col, weightCol=None)return a column of mean summary

New in version 2.4.0.

static metrics(*metrics)
Given a list of metrics, provides a builder that in turn computes metrics from a column.

See the documentation of Summarizer for an example.

The following metrics are accepted (case sensitive):

• mean: a vector that contains the coefficient-wise mean.

• sum: a vector that contains the coefficient-wise sum.

• variance: a vector that contains the coefficient-wise variance.

• std: a vector that contains the coefficient-wise standard deviation.

• count: the count of all vectors seen.

• numNonzeros: a vector with the number of non-zeros for each coefficient.

• max: the maximum for each coefficient.

• min: the minimum for each coefficient.

• normL2: the Euclidean norm for each coefficient.

• normL1: the L1 norm of each coefficient (sum of the absolute values).

New in version 2.4.0.

Returns

pyspark.ml.stat.SummaryBuilder

Notes

Currently, the performance of this interface is about 2x~3x slower than using the RDD interface.


Parameters

metrics [str] metrics that can be provided.
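A minimal sketch of requesting several metrics at once, assuming a DataFrame df with a features Vector column; the struct fields are assumed to be named after the requested metrics:

from pyspark.ml.stat import Summarizer

summarizer = Summarizer.metrics("min", "max", "variance")
summary = df.select(summarizer.summary(df.features).alias("stats"))
summary.select("stats.min", "stats.max", "stats.variance").show(truncate=False)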

static min(col, weightCol=None)return a column of min summary

New in version 2.4.0.

static normL1(col, weightCol=None)return a column of normL1 summary

New in version 2.4.0.

static normL2(col, weightCol=None)return a column of normL2 summary

New in version 2.4.0.

static numNonZeros(col, weightCol=None)return a column of numNonZero summary

New in version 2.4.0.

static std(col, weightCol=None)return a column of std summary

New in version 3.0.0.

static sum(col, weightCol=None)return a column of sum summary

New in version 3.0.0.

static variance(col, weightCol=None)return a column of variance summary

New in version 2.4.0.

SummaryBuilder

class pyspark.ml.stat.SummaryBuilder(jSummaryBuilder)
A builder object that provides summary statistics about a given column.

Users should not directly create such builders, but instead use one of the methods in pyspark.ml.stat.Summarizer.

New in version 2.4.0.

Methods

summary(featuresCol[, weightCol]) Returns an aggregate object that contains the summary of the column with the requested metrics.


Methods Documentation

summary(featuresCol, weightCol=None)Returns an aggregate object that contains the summary of the column with the requested metrics.

New in version 2.4.0.

Parameters

featuresCol [str] a column that contains features Vector object.

weightCol [str, optional] a column that contains weight value. Default weight is 1.0.

Returns

pyspark.sql.Column an aggregate column that contains the statistics. The exact content of this structure is determined during the creation of the builder.

3.3.11 Tuning

ParamGridBuilder() Builder for a param grid used in grid search-based model selection.
CrossValidator(*args, **kwargs) K-fold cross validation performs model selection by splitting the dataset into a set of non-overlapping randomly partitioned folds which are used as separate training and test datasets, e.g., with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing.
CrossValidatorModel(bestModel[, avgMetrics, ...]) CrossValidatorModel contains the model with the highest average cross-validation metric across folds and uses this model to transform input data.
TrainValidationSplit(*args, **kwargs) Validation for hyper-parameter tuning.
TrainValidationSplitModel(bestModel[, ...]) Model from train validation split.

ParamGridBuilder

class pyspark.ml.tuning.ParamGridBuilder
Builder for a param grid used in grid search-based model selection.

New in version 1.4.0.

Examples

>>> from pyspark.ml.classification import LogisticRegression
>>> lr = LogisticRegression()
>>> output = ParamGridBuilder() \
...     .baseOn({lr.labelCol: 'l'}) \
...     .baseOn([lr.predictionCol, 'p']) \
...     .addGrid(lr.regParam, [1.0, 2.0]) \
...     .addGrid(lr.maxIter, [1, 5]) \
...     .build()
>>> expected = [
...     {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
...     {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
...     {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
...     {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
>>> len(output) == len(expected)
True
>>> all([m in expected for m in output])
True

Methods

addGrid(param, values) Sets the given parameters in this grid to fixed values.
baseOn(*args) Sets the given parameters in this grid to fixed values.
build() Builds and returns all combinations of parameters specified by the param grid.

Methods Documentation

addGrid(param, values)
Sets the given parameters in this grid to fixed values.

param must be an instance of Param associated with an instance of Params (such as Estimator or Transformer).

New in version 1.4.0.

baseOn(*args)
Sets the given parameters in this grid to fixed values. Accepts either a parameter dictionary or a list of (parameter, value) pairs.

New in version 1.4.0.

build()
Builds and returns all combinations of parameters specified by the param grid.

New in version 1.4.0.

CrossValidator

class pyspark.ml.tuning.CrossValidator(*args, **kwargs)
K-fold cross validation performs model selection by splitting the dataset into a set of non-overlapping, randomly partitioned folds which are used as separate training and test datasets; e.g., with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. Each fold is used as the test set exactly once.

New in version 1.4.0.


Examples

>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.tuning import CrossValidatorModel
>>> import tempfile
>>> dataset = spark.createDataFrame(
...     [(Vectors.dense([0.0]), 0.0),
...      (Vectors.dense([0.4]), 1.0),
...      (Vectors.dense([0.5]), 0.0),
...      (Vectors.dense([0.6]), 1.0),
...      (Vectors.dense([1.0]), 1.0)] * 10,
...     ["features", "label"])
>>> lr = LogisticRegression()
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
...                     parallelism=2)
>>> cvModel = cv.fit(dataset)
>>> cvModel.getNumFolds()
3
>>> cvModel.avgMetrics[0]
0.5
>>> path = tempfile.mkdtemp()
>>> model_path = path + "/model"
>>> cvModel.write().save(model_path)
>>> cvModelRead = CrossValidatorModel.read().load(model_path)
>>> cvModelRead.avgMetrics
[0.5, ...
>>> evaluator.evaluate(cvModel.transform(dataset))
0.8333...
>>> evaluator.evaluate(cvModelRead.transform(dataset))
0.8333...

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with a randomly generated uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getCollectSubModels() Gets the value of collectSubModels or its default value.
getEstimator() Gets the value of estimator or its default value.
getEstimatorParamMaps() Gets the value of estimatorParamMaps or its default value.
getEvaluator() Gets the value of evaluator or its default value.
getFoldCol() Gets the value of foldCol or its default value.
getNumFolds() Gets the value of numFolds or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParallelism() Gets the value of parallelism or its default value.
getParam(paramName) Gets a param by its name.
getSeed() Gets the value of seed or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setCollectSubModels(value) Sets the value of collectSubModels.
setEstimator(value) Sets the value of estimator.
setEstimatorParamMaps(value) Sets the value of estimatorParamMaps.
setEvaluator(value) Sets the value of evaluator.
setFoldCol(value) Sets the value of foldCol.
setNumFolds(value) Sets the value of numFolds.
setParallelism(value) Sets the value of parallelism.
setParams(*args, **kwargs) setParams(self, *, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, seed=None, parallelism=1, collectSubModels=False, foldCol=""): Sets params for cross validator.
setSeed(value) Sets the value of seed.
write() Returns an MLWriter instance for this ML instance.


Attributes

collectSubModels
estimator
estimatorParamMaps
evaluator
foldCol
numFolds
parallelism
params Returns all params ordered by name.
seed

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with a randomly generated uid and some extra params. This copy creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over.

New in version 1.4.0.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

CrossValidator Copy of this instance

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None) Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)
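The list-of-param-maps form described above is the generic Estimator.fit contract. A minimal sketch, using LogisticRegression and toy data as illustrative assumptions:

    from pyspark.sql import SparkSession
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.linalg import Vectors

    spark = SparkSession.builder.getOrCreate()
    train = spark.createDataFrame(
        [(Vectors.dense([0.0]), 0.0), (Vectors.dense([1.0]), 1.0)] * 10,
        ["features", "label"])
    lr = LogisticRegression()
    # Passing a list of param maps returns one fitted model per map.
    models = lr.fit(train, params=[{lr.maxIter: 1}, {lr.maxIter: 5}])
    assert len(models) == 2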

fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
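A minimal sketch of consuming this iterator; the estimator, grid and data are illustrative assumptions, and note that the (index, model) pairs may arrive out of order:

    from pyspark.sql import SparkSession
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.linalg import Vectors

    spark = SparkSession.builder.getOrCreate()
    train = spark.createDataFrame(
        [(Vectors.dense([0.0]), 0.0), (Vectors.dense([1.0]), 1.0)] * 10,
        ["features", "label"])
    lr = LogisticRegression()
    param_maps = [{lr.maxIter: 1}, {lr.maxIter: 5}]
    models = [None] * len(param_maps)
    for index, model in lr.fitMultiple(train, param_maps):
        models[index] = model        # place each model by its param-map index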

getCollectSubModels() Gets the value of collectSubModels or its default value.

getEstimator() Gets the value of estimator or its default value.

New in version 2.0.0.

getEstimatorParamMaps() Gets the value of estimatorParamMaps or its default value.

New in version 2.0.0.

getEvaluator() Gets the value of evaluator or its default value.

New in version 2.0.0.

getFoldCol() Gets the value of foldCol or its default value.

New in version 3.1.0.

getNumFolds() Gets the value of numFolds or its default value.

New in version 1.4.0.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParallelism() Gets the value of parallelism or its default value.

getParam(paramName) Gets a param by its name.

getSeed() Gets the value of seed or its default value.


hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isSet(param) Checks whether a param is explicitly set by user.

classmethod load(path) Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read() Returns an MLReader instance for this class.

New in version 2.3.0.

save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value) Sets a parameter in the embedded param map.

setCollectSubModels(value) Sets the value of collectSubModels.

setEstimator(value) Sets the value of estimator.

New in version 2.0.0.

setEstimatorParamMaps(value) Sets the value of estimatorParamMaps.

New in version 2.0.0.

setEvaluator(value) Sets the value of evaluator.

New in version 2.0.0.

setFoldCol(value) Sets the value of foldCol.

New in version 3.1.0.

setNumFolds(value) Sets the value of numFolds.

New in version 1.4.0.

setParallelism(value) Sets the value of parallelism.

setParams(*args, **kwargs) setParams(self, *, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, seed=None, parallelism=1, collectSubModels=False, foldCol=""): Sets params for cross validator.

New in version 1.4.0.

setSeed(value) Sets the value of seed.


write() Returns an MLWriter instance for this ML instance.

New in version 2.3.0.

Attributes Documentation

collectSubModels = Param(parent='undefined', name='collectSubModels', doc='Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.')

estimator = Param(parent='undefined', name='estimator', doc='estimator to be cross-validated')

estimatorParamMaps = Param(parent='undefined', name='estimatorParamMaps', doc='estimator param maps')

evaluator = Param(parent='undefined', name='evaluator', doc='evaluator used to select hyper-parameters that maximize the validator metric')

foldCol = Param(parent='undefined', name='foldCol', doc="Param for the column name of user specified fold number. Once this is specified, :py:class:`CrossValidator` won't do random k-fold split. Note that this column should be integer type with range [0, numFolds) and Spark will throw exception on out-of-range fold numbers.")

numFolds = Param(parent='undefined', name='numFolds', doc='number of folds for cross validation')

parallelism = Param(parent='undefined', name='parallelism', doc='the number of threads to use when running parallel algorithms (>= 1).')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

seed = Param(parent='undefined', name='seed', doc='random seed.')
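For orientation, a minimal end-to-end sketch of CrossValidator; the estimator, grid and toy data below are illustrative assumptions rather than part of the API:

    from pyspark.sql import SparkSession
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    spark = SparkSession.builder.getOrCreate()
    dataset = spark.createDataFrame(
        [(Vectors.dense([0.0]), 0.0),
         (Vectors.dense([0.4]), 1.0),
         (Vectors.dense([0.6]), 1.0),
         (Vectors.dense([1.0]), 1.0)] * 10,
        ["features", "label"])

    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                        evaluator=BinaryClassificationEvaluator(),
                        numFolds=3, parallelism=2, seed=42)
    cvModel = cv.fit(dataset)        # returns a CrossValidatorModel
    print(cvModel.avgMetrics)        # average metric per param map, in grid order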

CrossValidatorModel

class pyspark.ml.tuning.CrossValidatorModel(bestModel, avgMetrics=None, subModels=None)

CrossValidatorModel contains the model with the highest average cross-validation metric across folds and uses this model to transform input data. CrossValidatorModel also tracks the metrics for each param map evaluated.

New in version 1.4.0.
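Continuing the hedged sketch above, where cvModel is the CrossValidatorModel returned by CrossValidator.fit, the pieces this class exposes can be inspected as follows:

    best = cvModel.bestModel                  # model trained with the best param map
    metrics = cvModel.avgMetrics              # average cross-validation metric per param map
    predictions = cvModel.transform(dataset)  # transform() delegates to bestModel
    predictions.select("features", "prediction").show(5)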

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with a randomly generated uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getEstimator() Gets the value of estimator or its default value.
getEstimatorParamMaps() Gets the value of estimatorParamMaps or its default value.
getEvaluator() Gets the value of evaluator or its default value.
getFoldCol() Gets the value of foldCol or its default value.
getNumFolds() Gets the value of numFolds or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getSeed() Gets the value of seed or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

estimator
estimatorParamMaps
evaluator
foldCol
numFolds
params Returns all params ordered by name.
seed

Methods Documentation

clear(param) Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with a randomly generated uid and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. It does not copy the extra Params into the subModels.

New in version 1.4.0.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns


CrossValidatorModel Copy of this instance

explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getEstimator() Gets the value of estimator or its default value.

New in version 2.0.0.

getEstimatorParamMaps() Gets the value of estimatorParamMaps or its default value.

New in version 2.0.0.

getEvaluator() Gets the value of evaluator or its default value.

New in version 2.0.0.

getFoldCol() Gets the value of foldCol or its default value.

New in version 3.1.0.

getNumFolds() Gets the value of numFolds or its default value.

New in version 1.4.0.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName) Gets a param by its name.

getSeed() Gets the value of seed or its default value.

hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.


isSet(param) Checks whether a param is explicitly set by user.

classmethod load(path) Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read() Returns an MLReader instance for this class.

New in version 2.3.0.

save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value) Sets a parameter in the embedded param map.

transform(dataset, params=None) Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write() Returns an MLWriter instance for this ML instance.

New in version 2.3.0.

Attributes Documentation

estimator = Param(parent='undefined', name='estimator', doc='estimator to be cross-validated')

estimatorParamMaps = Param(parent='undefined', name='estimatorParamMaps', doc='estimator param maps')

evaluator = Param(parent='undefined', name='evaluator', doc='evaluator used to select hyper-parameters that maximize the validator metric')

foldCol = Param(parent='undefined', name='foldCol', doc="Param for the column name of user specified fold number. Once this is specified, :py:class:`CrossValidator` won't do random k-fold split. Note that this column should be integer type with range [0, numFolds) and Spark will throw exception on out-of-range fold numbers.")

numFolds = Param(parent='undefined', name='numFolds', doc='number of folds for cross validation')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

seed = Param(parent='undefined', name='seed', doc='random seed.')


TrainValidationSplit

class pyspark.ml.tuning.TrainValidationSplit(*args, **kwargs)

Validation for hyper-parameter tuning. Randomly splits the input dataset into train and validation sets, and uses the evaluation metric on the validation set to select the best model. Similar to CrossValidator, but only splits the set once.

New in version 2.0.0.

Examples

>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplitModel
>>> import tempfile
>>> dataset = spark.createDataFrame(
...     [(Vectors.dense([0.0]), 0.0),
...      (Vectors.dense([0.4]), 1.0),
...      (Vectors.dense([0.5]), 0.0),
...      (Vectors.dense([0.6]), 1.0),
...      (Vectors.dense([1.0]), 1.0)] * 10,
...     ["features", "label"]).repartition(1)
>>> lr = LogisticRegression()
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
...     parallelism=1, seed=42)
>>> tvsModel = tvs.fit(dataset)
>>> tvsModel.getTrainRatio()
0.75
>>> tvsModel.validationMetrics
[0.5, ...
>>> path = tempfile.mkdtemp()
>>> model_path = path + "/model"
>>> tvsModel.write().save(model_path)
>>> tvsModelRead = TrainValidationSplitModel.read().load(model_path)
>>> tvsModelRead.validationMetrics
[0.5, ...
>>> evaluator.evaluate(tvsModel.transform(dataset))
0.833...
>>> evaluator.evaluate(tvsModelRead.transform(dataset))
0.833...

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with a randomly generated uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
fit(dataset[, params]) Fits a model to the input dataset with optional parameters.
fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.
getCollectSubModels() Gets the value of collectSubModels or its default value.
getEstimator() Gets the value of estimator or its default value.
getEstimatorParamMaps() Gets the value of estimatorParamMaps or its default value.
getEvaluator() Gets the value of evaluator or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParallelism() Gets the value of parallelism or its default value.
getParam(paramName) Gets a param by its name.
getSeed() Gets the value of seed or its default value.
getTrainRatio() Gets the value of trainRatio or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setCollectSubModels(value) Sets the value of collectSubModels.
setEstimator(value) Sets the value of estimator.
setEstimatorParamMaps(value) Sets the value of estimatorParamMaps.
setEvaluator(value) Sets the value of evaluator.
setParallelism(value) Sets the value of parallelism.
setParams(*args, **kwargs) setParams(self, *, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None): Sets params for the train validation split.
setSeed(value) Sets the value of seed.
setTrainRatio(value) Sets the value of trainRatio.
write() Returns an MLWriter instance for this ML instance.

Attributes

collectSubModels
estimator
estimatorParamMaps
evaluator
parallelism
params Returns all params ordered by name.
seed
trainRatio

Methods Documentation

clear(param) Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with a randomly generated uid and some extra params. This creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over.

New in version 2.0.0.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

TrainValidationSplit Copy of this instance

explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

fit(dataset, params=None) Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

params [dict or list or tuple, optional] an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns

Transformer or a list of Transformer fitted model(s)

fitMultiple(dataset, paramMaps) Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset.

paramMaps [collections.abc.Sequence] A Sequence of param maps.

Returns

_FitMultipleIterator A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

getCollectSubModels() Gets the value of collectSubModels or its default value.

getEstimator() Gets the value of estimator or its default value.

New in version 2.0.0.

getEstimatorParamMaps() Gets the value of estimatorParamMaps or its default value.

New in version 2.0.0.

getEvaluator() Gets the value of evaluator or its default value.

New in version 2.0.0.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParallelism() Gets the value of parallelism or its default value.

getParam(paramName) Gets a param by its name.

getSeed() Gets the value of seed or its default value.

getTrainRatio() Gets the value of trainRatio or its default value.

New in version 2.0.0.

hasDefault(param) Checks whether a param has a default value.


hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isSet(param) Checks whether a param is explicitly set by user.

classmethod load(path) Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read() Returns an MLReader instance for this class.

New in version 2.3.0.

save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value) Sets a parameter in the embedded param map.

setCollectSubModels(value) Sets the value of collectSubModels.

setEstimator(value) Sets the value of estimator.

New in version 2.0.0.

setEstimatorParamMaps(value) Sets the value of estimatorParamMaps.

New in version 2.0.0.

setEvaluator(value) Sets the value of evaluator.

New in version 2.0.0.

setParallelism(value) Sets the value of parallelism.

setParams(*args, **kwargs) setParams(self, *, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None): Sets params for the train validation split.

New in version 2.0.0.

setSeed(value) Sets the value of seed.

setTrainRatio(value) Sets the value of trainRatio.

New in version 2.0.0.

write() Returns an MLWriter instance for this ML instance.

New in version 2.3.0.


Attributes Documentation

collectSubModels = Param(parent='undefined', name='collectSubModels', doc='Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.')

estimator = Param(parent='undefined', name='estimator', doc='estimator to be cross-validated')

estimatorParamMaps = Param(parent='undefined', name='estimatorParamMaps', doc='estimator param maps')

evaluator = Param(parent='undefined', name='evaluator', doc='evaluator used to select hyper-parameters that maximize the validator metric')

parallelism = Param(parent='undefined', name='parallelism', doc='the number of threads to use when running parallel algorithms (>= 1).')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

seed = Param(parent='undefined', name='seed', doc='random seed.')

trainRatio = Param(parent='undefined', name='trainRatio', doc='Param for ratio between train and validation data. Must be between 0 and 1.')

TrainValidationSplitModel

class pyspark.ml.tuning.TrainValidationSplitModel(bestModel, validationMetrics=None, subModels=None)

Model from train validation split.

New in version 2.0.0.
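Continuing from the TrainValidationSplit example above (where tvsModel = tvs.fit(dataset)), a brief hedged sketch of what this model exposes; the attribute uses below are illustrative:

    metrics = tvsModel.validationMetrics       # one metric per param map, in grid order
    best = tvsModel.bestModel                  # model trained with the best param map
    predictions = tvsModel.transform(dataset)  # transform() delegates to bestModel
    predictions.select("features", "prediction").show(5)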

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with a randomly generated uid and some extra params.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getEstimator() Gets the value of estimator or its default value.
getEstimatorParamMaps() Gets the value of estimatorParamMaps or its default value.
getEvaluator() Gets the value of evaluator or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getSeed() Gets the value of seed or its default value.
getTrainRatio() Gets the value of trainRatio or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
transform(dataset[, params]) Transforms the input dataset with optional parameters.
write() Returns an MLWriter instance for this ML instance.

Attributes

estimator
estimatorParamMaps
evaluator
params Returns all params ordered by name.
seed
trainRatio

Methods Documentation

clear(param) Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with a randomly generated uid and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. It also creates a shallow copy of the validationMetrics. It does not copy the extra Params into the subModels.

New in version 2.0.0.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

TrainValidationSplitModel Copy of this instance

explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getEstimator() Gets the value of estimator or its default value.

New in version 2.0.0.

getEstimatorParamMaps() Gets the value of estimatorParamMaps or its default value.

New in version 2.0.0.

getEvaluator() Gets the value of evaluator or its default value.

New in version 2.0.0.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName) Gets a param by its name.

getSeed() Gets the value of seed or its default value.

getTrainRatio() Gets the value of trainRatio or its default value.

New in version 2.0.0.

hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isSet(param) Checks whether a param is explicitly set by user.

classmethod load(path) Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read() Returns an MLReader instance for this class.

New in version 2.3.0.

save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value) Sets a parameter in the embedded param map.

3.3. ML 967

Page 972: pyspark Documentation

pyspark Documentation, Release master

transform(dataset, params=None) Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters

dataset [pyspark.sql.DataFrame] input dataset

params [dict, optional] an optional param map that overrides embedded params.

Returns

pyspark.sql.DataFrame transformed dataset

write() Returns an MLWriter instance for this ML instance.

New in version 2.3.0.

Attributes Documentation

estimator = Param(parent='undefined', name='estimator', doc='estimator to be cross-validated')

estimatorParamMaps = Param(parent='undefined', name='estimatorParamMaps', doc='estimator param maps')

evaluator = Param(parent='undefined', name='evaluator', doc='evaluator used to select hyper-parameters that maximize the validator metric')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

seed = Param(parent='undefined', name='seed', doc='random seed.')

trainRatio = Param(parent='undefined', name='trainRatio', doc='Param for ratio between train and validation data. Must be between 0 and 1.')

3.3.12 Evaluation

Evaluator() Base class for evaluators that compute metrics from predictions.
BinaryClassificationEvaluator(*args, **kwargs) Evaluator for binary classification, which expects input columns rawPrediction, label and an optional weight column.
RegressionEvaluator(*args, **kwargs) Evaluator for Regression, which expects input columns prediction, label and an optional weight column.
MulticlassClassificationEvaluator(*args, ...) Evaluator for Multiclass Classification, which expects input columns: prediction, label, weight (optional) and probabilityCol (only for logLoss).
MultilabelClassificationEvaluator(*args, ...) Evaluator for Multilabel Classification, which expects two input columns: prediction and label.
ClusteringEvaluator(*args, **kwargs) Evaluator for Clustering results, which expects two input columns: prediction and features.
RankingEvaluator(*args, **kwargs) Evaluator for Ranking, which expects two input columns: prediction and label.


Evaluator

class pyspark.ml.evaluation.Evaluator

Base class for evaluators that compute metrics from predictions.

New in version 1.4.0.
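Because every evaluator implements evaluate() and isLargerBetter(), model selection can be written once against this base class. A minimal sketch; the helper name and its arguments are illustrative, not part of the API:

    def select_best(models, evaluator, validation_df):
        # Score each fitted model on the validation DataFrame.
        scored = [(m, evaluator.evaluate(m.transform(validation_df))) for m in models]
        # Respect whether the metric should be maximized or minimized.
        pick = max if evaluator.isLargerBetter() else min
        return pick(scored, key=lambda pair: pair[1])[0]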

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset[, params]) Evaluates the output with optional parameters.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isLargerBetter() Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False).
isSet(param) Checks whether a param is explicitly set by user.
set(param, value) Sets a parameter in the embedded param map.

Attributes

params Returns all params ordered by name.


Methods Documentation

clear(param) Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with the same uid and some extra params. The default implementation creates a shallow copy using copy.copy(), and then copies the embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

Params Copy of this instance

evaluate(dataset, params=None) Evaluates the output with optional parameters.

New in version 1.4.0.

Parameters

dataset [pyspark.sql.DataFrame] a dataset that contains labels/observations and predictions

params [dict, optional] an optional param map that overrides embedded params

Returns

float metric

explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName) Gets a param by its name.

hasDefault(param) Checks whether a param has a default value.


hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isLargerBetter() Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False). A given evaluator may support multiple metrics which may be maximized or minimized.

New in version 1.5.0.

isSet(param) Checks whether a param is explicitly set by user.

set(param, value) Sets a parameter in the embedded param map.

Attributes Documentation

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

BinaryClassificationEvaluator

class pyspark.ml.evaluation.BinaryClassificationEvaluator(*args, **kwargs)

Evaluator for binary classification, which expects input columns rawPrediction, label and an optional weight column. The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw predictions, scores, or label probabilities).

New in version 1.4.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> scoreAndLabels = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]),
...    [(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)])
>>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"])
...
>>> evaluator = BinaryClassificationEvaluator()
>>> evaluator.setRawPredictionCol("raw")
BinaryClassificationEvaluator...
>>> evaluator.evaluate(dataset)
0.70...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
0.83...
>>> bce_path = temp_path + "/bce"
>>> evaluator.save(bce_path)
>>> evaluator2 = BinaryClassificationEvaluator.load(bce_path)
>>> str(evaluator2.getRawPredictionCol())
'raw'
>>> scoreAndLabelsAndWeight = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1], x[2]),
...    [(0.1, 0.0, 1.0), (0.1, 1.0, 0.9), (0.4, 0.0, 0.7), (0.6, 0.0, 0.9),
...     (0.6, 1.0, 1.0), (0.6, 1.0, 0.3), (0.8, 1.0, 1.0)])
>>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["raw", "label", "weight"])
...
>>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw", weightCol="weight")
>>> evaluator.evaluate(dataset)
0.70...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
0.82...
>>> evaluator.getNumBins()
1000

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset[, params]) Evaluates the output with optional parameters.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getLabelCol() Gets the value of labelCol or its default value.
getMetricName() Gets the value of metricName or its default value.
getNumBins() Gets the value of numBins or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getRawPredictionCol() Gets the value of rawPredictionCol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isLargerBetter() Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False).
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setLabelCol(value) Sets the value of labelCol.
setMetricName(value) Sets the value of metricName.
setNumBins(value) Sets the value of numBins.
setParams(self, *[, rawPredictionCol, ...]) Sets params for binary classification evaluator.
setRawPredictionCol(value) Sets the value of rawPredictionCol.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.

Attributes

labelCol
metricName
numBins
params Returns all params ordered by name.
rawPredictionCol
weightCol

Methods Documentation

clear(param) Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluate(dataset, params=None) Evaluates the output with optional parameters.

New in version 1.4.0.

Parameters

dataset [pyspark.sql.DataFrame] a dataset that contains labels/observations and predictions

params [dict, optional] an optional param map that overrides embedded params

Returns

float metric


explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getLabelCol() Gets the value of labelCol or its default value.

getMetricName() Gets the value of metricName or its default value.

New in version 1.4.0.

getNumBins() Gets the value of numBins or its default value.

New in version 3.0.0.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName) Gets a param by its name.

getRawPredictionCol() Gets the value of rawPredictionCol or its default value.

getWeightCol() Gets the value of weightCol or its default value.

hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isLargerBetter() Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False). A given evaluator may support multiple metrics which may be maximized or minimized.

New in version 1.5.0.

isSet(param) Checks whether a param is explicitly set by user.


classmethod load(path) Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read() Returns an MLReader instance for this class.

save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value) Sets a parameter in the embedded param map.

setLabelCol(value) Sets the value of labelCol.

setMetricName(value) Sets the value of metricName.

New in version 1.4.0.

setNumBins(value) Sets the value of numBins.

New in version 3.0.0.

setParams(self, *, rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC", weightCol=None, numBins=1000)

Sets params for binary classification evaluator.

New in version 1.4.0.

setRawPredictionCol(value) Sets the value of rawPredictionCol.

setWeightCol(value) Sets the value of weightCol.

New in version 3.0.0.

write() Returns an MLWriter instance for this ML instance.

Attributes Documentation

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

metricName = Param(parent='undefined', name='metricName', doc='metric name in evaluation (areaUnderROC|areaUnderPR)')

numBins = Param(parent='undefined', name='numBins', doc='Number of bins to down-sample the curves (ROC curve, PR curve) in area computation. If 0, no down-sampling will occur. Must be >= 0.')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

rawPredictionCol = Param(parent='undefined', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')


RegressionEvaluator

class pyspark.ml.evaluation.RegressionEvaluator(*args, **kwargs)

Evaluator for Regression, which expects input columns prediction, label and an optional weight column.

New in version 1.4.0.

Examples

>>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5),
...     (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)]
>>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"])
...
>>> evaluator = RegressionEvaluator()
>>> evaluator.setPredictionCol("raw")
RegressionEvaluator...
>>> evaluator.evaluate(dataset)
2.842...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
0.993...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
2.649...
>>> re_path = temp_path + "/re"
>>> evaluator.save(re_path)
>>> evaluator2 = RegressionEvaluator.load(re_path)
>>> str(evaluator2.getPredictionCol())
'raw'
>>> scoreAndLabelsAndWeight = [(-28.98343821, -27.0, 1.0), (20.21491975, 21.5, 0.8),
...     (-25.98418959, -22.0, 1.0), (30.69731842, 33.0, 0.6), (74.69283752, 71.0, 0.2)]
>>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["raw", "label", "weight"])
...
>>> evaluator = RegressionEvaluator(predictionCol="raw", weightCol="weight")
>>> evaluator.evaluate(dataset)
2.740...
>>> evaluator.getThroughOrigin()
False

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset[, params]) Evaluates the output with optional parameters.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getLabelCol() Gets the value of labelCol or its default value.
getMetricName() Gets the value of metricName or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getThroughOrigin() Gets the value of throughOrigin or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isLargerBetter() Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False).
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setLabelCol(value) Sets the value of labelCol.
setMetricName(value) Sets the value of metricName.
setParams(self, *[, predictionCol, ...]) Sets params for regression evaluator.
setPredictionCol(value) Sets the value of predictionCol.
setThroughOrigin(value) Sets the value of throughOrigin.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.

Attributes

labelCol
metricName
params Returns all params ordered by name.
predictionCol
throughOrigin
weightCol


Methods Documentation

clear(param) Clears a param from the param map if it has been explicitly set.

copy(extra=None) Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluate(dataset, params=None) Evaluates the output with optional parameters.

New in version 1.4.0.

Parameters

dataset [pyspark.sql.DataFrame] a dataset that contains labels/observations and predictions

params [dict, optional] an optional param map that overrides embedded params

Returns

float metric

explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getLabelCol() Gets the value of labelCol or its default value.

getMetricName() Gets the value of metricName or its default value.

New in version 1.4.0.

getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.


getParam(paramName) Gets a param by its name.

getPredictionCol() Gets the value of predictionCol or its default value.

getThroughOrigin() Gets the value of throughOrigin or its default value.

New in version 3.0.0.

getWeightCol() Gets the value of weightCol or its default value.

hasDefault(param) Checks whether a param has a default value.

hasParam(paramName) Tests whether this instance contains a param with a given (string) name.

isDefined(param) Checks whether a param is explicitly set by user or has a default value.

isLargerBetter() Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False). A given evaluator may support multiple metrics which may be maximized or minimized.

New in version 1.5.0.

isSet(param) Checks whether a param is explicitly set by user.

classmethod load(path) Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read() Returns an MLReader instance for this class.

save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value) Sets a parameter in the embedded param map.

setLabelCol(value) Sets the value of labelCol.

setMetricName(value) Sets the value of metricName.

New in version 1.4.0.

setParams(self, *, predictionCol="prediction", labelCol="label", metricName="rmse", weightCol=None, throughOrigin=False)

Sets params for regression evaluator.

New in version 1.4.0.

setPredictionCol(value) Sets the value of predictionCol.

setThroughOrigin(value) Sets the value of throughOrigin.

New in version 3.0.0.


setWeightCol(value) Sets the value of weightCol.

New in version 3.0.0.

write() Returns an MLWriter instance for this ML instance.

Attributes Documentation

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

metricName = Param(parent='undefined', name='metricName', doc='metric name in evaluation - one of:\n rmse - root mean squared error (default)\n mse - mean squared error\n r2 - r^2 metric\n mae - mean absolute error\n var - explained variance.')

params Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

throughOrigin = Param(parent='undefined', name='throughOrigin', doc='whether the regression is through the origin.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')

MulticlassClassificationEvaluator

class pyspark.ml.evaluation.MulticlassClassificationEvaluator(*args, **kwargs)

Evaluator for Multiclass Classification, which expects input columns: prediction, label, weight (optional) and probabilityCol (only for logLoss).

New in version 1.5.0.

Examples

>>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]
>>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
>>> evaluator = MulticlassClassificationEvaluator()
>>> evaluator.setPredictionCol("prediction")
MulticlassClassificationEvaluator...
>>> evaluator.evaluate(dataset)
0.66...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
0.66...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "truePositiveRateByLabel",
...     evaluator.metricLabel: 1.0})
0.75...
>>> evaluator.setMetricName("hammingLoss")
MulticlassClassificationEvaluator...
>>> evaluator.evaluate(dataset)
0.33...
>>> mce_path = temp_path + "/mce"
>>> evaluator.save(mce_path)
>>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path)
>>> str(evaluator2.getPredictionCol())
'prediction'
>>> scoreAndLabelsAndWeight = [(0.0, 0.0, 1.0), (0.0, 1.0, 1.0), (0.0, 0.0, 1.0),
...     (1.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0),
...     (2.0, 2.0, 1.0), (2.0, 0.0, 1.0)]
>>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["prediction", "label", "weight"])
>>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction",
...     weightCol="weight")
>>> evaluator.evaluate(dataset)
0.66...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
0.66...
>>> predictionAndLabelsWithProbabilities = [
...     (1.0, 1.0, 1.0, [0.1, 0.8, 0.1]), (0.0, 2.0, 1.0, [0.9, 0.05, 0.05]),
...     (0.0, 0.0, 1.0, [0.8, 0.2, 0.0]), (1.0, 1.0, 1.0, [0.3, 0.65, 0.05])]
>>> dataset = spark.createDataFrame(predictionAndLabelsWithProbabilities, ["prediction",
...     "label", "weight", "probability"])
>>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction",
...     probabilityCol="probability")
>>> evaluator.setMetricName("logLoss")
MulticlassClassificationEvaluator...
>>> evaluator.evaluate(dataset)
0.9682...

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset[, params]) Evaluates the output with optional parameters.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optionally default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getBeta() Gets the value of beta or its default value.
getEps() Gets the value of eps or its default value.
getLabelCol() Gets the value of labelCol or its default value.
getMetricLabel() Gets the value of metricLabel or its default value.
getMetricName() Gets the value of metricName or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getProbabilityCol() Gets the value of probabilityCol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isLargerBetter() Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False).
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
set(param, value) Sets a parameter in the embedded param map.
setBeta(value) Sets the value of beta.
setEps(value) Sets the value of eps.
setLabelCol(value) Sets the value of labelCol.
setMetricLabel(value) Sets the value of metricLabel.
setMetricName(value) Sets the value of metricName.
setParams(self, *[, predictionCol, ...]) Sets params for multiclass classification evaluator.
setPredictionCol(value) Sets the value of predictionCol.
setProbabilityCol(value) Sets the value of probabilityCol.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.

Attributes

beta
eps
labelCol
metricLabel
metricName
params Returns all params ordered by name.
predictionCol
probabilityCol
weightCol

Methods Documentation

clear(param)Clears a param from the param map if it has been explicitly set.

copy(extra=None)Creates a copy of this instance with the same uid and some extra params. This implementation first callsParams.copy and then make a copy of the companion Java pipeline component with extra params. So boththe Python wrapper and the Java pipeline component get copied.

Parameters


extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluate(dataset, params=None)
Evaluates the output with optional parameters.

New in version 1.4.0.

Parameters

dataset [pyspark.sql.DataFrame] a dataset that contains labels/observations and predictions

params [dict, optional] an optional param map that overrides embedded params

Returns

float metric

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map
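For illustration, a minimal sketch of the merge ordering, assuming an active SparkSession as in the examples above (the metric names here are arbitrary choices):

from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# The default metricName is "f1"; the constructor keyword is a user-supplied value,
# and the extra map passed to extractParamMap() overrides both.
evaluator = MulticlassClassificationEvaluator(metricName="accuracy")
merged = evaluator.extractParamMap({evaluator.metricName: "logLoss"})
assert merged[evaluator.metricName] == "logLoss"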

getBeta()
Gets the value of beta or its default value.

New in version 3.0.0.

getEps()
Gets the value of eps or its default value.

New in version 3.0.0.

getLabelCol()
Gets the value of labelCol or its default value.

getMetricLabel()
Gets the value of metricLabel or its default value.

New in version 3.0.0.

getMetricName()
Gets the value of metricName or its default value.

New in version 1.5.0.


getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getProbabilityCol()
Gets the value of probabilityCol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isLargerBetter()
Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False). A given evaluator may support multiple metrics which may be maximized or minimized.

New in version 1.5.0.
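As a sketch of how this flag can drive model selection regardless of the metric's direction (the metric choice and scores below are illustrative only, and an active SparkSession is assumed):

from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(metricName="logLoss")
scores = [0.41, 0.37]  # e.g. evaluator.evaluate(...) run against two candidate models
best = max(scores) if evaluator.isLargerBetter() else min(scores)
# logLoss is minimized, so isLargerBetter() returns False and best == 0.37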

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setBeta(value)
Sets the value of beta.

New in version 3.0.0.

setEps(value)
Sets the value of eps.

New in version 3.0.0.

setLabelCol(value)
Sets the value of labelCol.

setMetricLabel(value)
Sets the value of metricLabel.

New in version 3.0.0.


setMetricName(value)
Sets the value of metricName.

New in version 1.5.0.

setParams(self, *, predictionCol="prediction", labelCol="label", metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0, probabilityCol="probability", eps=1e-15)

Sets params for multiclass classification evaluator.

New in version 1.5.0.
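A short sketch of configuring the evaluator in one call (the metric and column names are illustrative); the individual setters return the evaluator itself, so chaining them has the same effect:

from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator()
evaluator.setParams(metricName="weightedRecall", weightCol="weight")
# equivalent to: evaluator.setMetricName("weightedRecall").setWeightCol("weight")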

setPredictionCol(value)
Sets the value of predictionCol.

setProbabilityCol(value)
Sets the value of probabilityCol.

New in version 3.0.0.

setWeightCol(value)
Sets the value of weightCol.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

beta = Param(parent='undefined', name='beta', doc='The beta value used in weightedFMeasure|fMeasureByLabel. Must be > 0. The default value is 1.')

eps = Param(parent='undefined', name='eps', doc='log-loss is undefined for p=0 or p=1, so probabilities are clipped to max(eps, min(1 - eps, p)). Must be in range (0, 0.5). The default value is 1e-15.')

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

metricLabel = Param(parent='undefined', name='metricLabel', doc='The class whose metric will be computed in truePositiveRateByLabel|falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel. Must be >= 0. The default value is 0.')

metricName = Param(parent='undefined', name='metricName', doc='metric name in evaluation (f1|accuracy|weightedPrecision|weightedRecall|weightedTruePositiveRate| weightedFalsePositiveRate|weightedFMeasure|truePositiveRateByLabel| falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel| logLoss|hammingLoss)')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

probabilityCol = Param(parent='undefined', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
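For the per-label metrics listed under metricName, metricLabel selects the class being scored. A hedged sketch, assuming a dataset with "prediction" and "label" columns shaped like the example above:

from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# F-measure for class 1.0 only; beta keeps the default weighting of precision vs. recall
evaluator = MulticlassClassificationEvaluator(metricName="fMeasureByLabel",
                                              metricLabel=1.0, beta=1.0)
# f1_for_class_1 = evaluator.evaluate(dataset)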

MultilabelClassificationEvaluator

class pyspark.ml.evaluation.MultilabelClassificationEvaluator(*args, **kwargs)
Evaluator for Multilabel Classification, which expects two input columns: prediction and label.

New in version 3.0.0.


Notes

Experimental

Examples

>>> scoreAndLabels = [([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
...     ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
...     ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]
>>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
...
>>> evaluator = MultilabelClassificationEvaluator()
>>> evaluator.setPredictionCol("prediction")
MultilabelClassificationEvaluator...
>>> evaluator.evaluate(dataset)
0.63...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
0.54...
>>> mlce_path = temp_path + "/mlce"
>>> evaluator.save(mlce_path)
>>> evaluator2 = MultilabelClassificationEvaluator.load(mlce_path)
>>> str(evaluator2.getPredictionCol())
'prediction'

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset[, params]) Evaluates the output with optional parameters.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getLabelCol() Gets the value of labelCol or its default value.
getMetricLabel() Gets the value of metricLabel or its default value.
getMetricName() Gets the value of metricName or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isLargerBetter() Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False).
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setLabelCol(value) Sets the value of labelCol.
setMetricLabel(value) Sets the value of metricLabel.
setMetricName(value) Sets the value of metricName.
setParams(self, *[, predictionCol, ...]) Sets params for multilabel classification evaluator.
setPredictionCol(value) Sets the value of predictionCol.
write() Returns an MLWriter instance for this ML instance.

Attributes

labelCol
metricLabel
metricName
params Returns all params ordered by name.
predictionCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluate(dataset, params=None)
Evaluates the output with optional parameters.

New in version 1.4.0.

Parameters


dataset [pyspark.sql.DataFrame] a dataset that contains labels/observations and predictions

params [dict, optional] an optional param map that overrides embedded params

Returns

float metric

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns

dict merged param map

getLabelCol()
Gets the value of labelCol or its default value.

getMetricLabel()
Gets the value of metricLabel or its default value.

New in version 3.0.0.

getMetricName()
Gets the value of metricName or its default value.

New in version 3.0.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isLargerBetter()
Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False). A given evaluator may support multiple metrics which may be maximized or minimized.


New in version 1.5.0.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setLabelCol(value)
Sets the value of labelCol.

New in version 3.0.0.

setMetricLabel(value)
Sets the value of metricLabel.

New in version 3.0.0.

setMetricName(value)
Sets the value of metricName.

New in version 3.0.0.

setParams(self, *, predictionCol="prediction", labelCol="label", metricName="f1Measure", metricLabel=0.0)

Sets params for multilabel classification evaluator.

New in version 3.0.0.

setPredictionCol(value)
Sets the value of predictionCol.

New in version 3.0.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

labelCol = Param(parent='undefined', name='labelCol', doc='label column name.')

metricLabel = Param(parent='undefined', name='metricLabel', doc='The class whose metric will be computed in precisionByLabel|recallByLabel|f1MeasureByLabel. Must be >= 0. The default value is 0.')

metricName = Param(parent='undefined', name='metricName', doc='metric name in evaluation (subsetAccuracy|accuracy|hammingLoss|precision|recall|f1Measure|precisionByLabel|recallByLabel|f1MeasureByLabel|microPrecision|microRecall|microF1Measure)')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')
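A short sketch of switching to one of the micro-averaged metrics, assuming a dataset whose prediction and label columns are arrays of doubles, as in the example above:

from pyspark.ml.evaluation import MultilabelClassificationEvaluator

evaluator = MultilabelClassificationEvaluator(metricName="microF1Measure")
# score = evaluator.evaluate(dataset)  # pooled over all instances and labels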


ClusteringEvaluator

class pyspark.ml.evaluation.ClusteringEvaluator(*args, **kwargs)
Evaluator for Clustering results, which expects two input columns: prediction and features. The metric computes the Silhouette measure using the squared Euclidean distance.

The Silhouette is a measure for the validation of the consistency within clusters. It ranges between 1 and -1, where a value close to 1 means that the points in a cluster are close to the other points in the same cluster and far from the points of the other clusters.

New in version 2.3.0.

Examples

>>> from pyspark.ml.linalg import Vectors
>>> featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]),
...     [([0.0, 0.5], 0.0), ([0.5, 0.0], 0.0), ([10.0, 11.0], 1.0),
...     ([10.5, 11.5], 1.0), ([1.0, 1.0], 0.0), ([8.0, 6.0], 1.0)])
>>> dataset = spark.createDataFrame(featureAndPredictions, ["features", "prediction"])
...
>>> evaluator = ClusteringEvaluator()
>>> evaluator.setPredictionCol("prediction")
ClusteringEvaluator...
>>> evaluator.evaluate(dataset)
0.9079...
>>> featureAndPredictionsWithWeight = map(lambda x: (Vectors.dense(x[0]), x[1], x[2]),
...     [([0.0, 0.5], 0.0, 2.5), ([0.5, 0.0], 0.0, 2.5), ([10.0, 11.0], 1.0, 2.5),
...     ([10.5, 11.5], 1.0, 2.5), ([1.0, 1.0], 0.0, 2.5), ([8.0, 6.0], 1.0, 2.5)])
>>> dataset = spark.createDataFrame(
...     featureAndPredictionsWithWeight, ["features", "prediction", "weight"])
>>> evaluator = ClusteringEvaluator()
>>> evaluator.setPredictionCol("prediction")
ClusteringEvaluator...
>>> evaluator.setWeightCol("weight")
ClusteringEvaluator...
>>> evaluator.evaluate(dataset)
0.9079...
>>> ce_path = temp_path + "/ce"
>>> evaluator.save(ce_path)
>>> evaluator2 = ClusteringEvaluator.load(ce_path)
>>> str(evaluator2.getPredictionCol())
'prediction'

Methods

clear(param) Clears a param from the param map if it has been explicitly set.
copy([extra]) Creates a copy of this instance with the same uid and some extra params.
evaluate(dataset[, params]) Evaluates the output with optional parameters.
explainParam(param) Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
explainParams() Returns the documentation of all params with their optional default values and user-supplied values.
extractParamMap([extra]) Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
getDistanceMeasure() Gets the value of distanceMeasure.
getFeaturesCol() Gets the value of featuresCol or its default value.
getMetricName() Gets the value of metricName or its default value.
getOrDefault(param) Gets the value of a param in the user-supplied param map or its default value.
getParam(paramName) Gets a param by its name.
getPredictionCol() Gets the value of predictionCol or its default value.
getWeightCol() Gets the value of weightCol or its default value.
hasDefault(param) Checks whether a param has a default value.
hasParam(paramName) Tests whether this instance contains a param with a given (string) name.
isDefined(param) Checks whether a param is explicitly set by user or has a default value.
isLargerBetter() Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False).
isSet(param) Checks whether a param is explicitly set by user.
load(path) Reads an ML instance from the input path, a shortcut of read().load(path).
read() Returns an MLReader instance for this class.
save(path) Save this ML instance to the given path, a shortcut of 'write().save(path)'.
set(param, value) Sets a parameter in the embedded param map.
setDistanceMeasure(value) Sets the value of distanceMeasure.
setFeaturesCol(value) Sets the value of featuresCol.
setMetricName(value) Sets the value of metricName.
setParams(self, *[, predictionCol, ...]) Sets params for clustering evaluator.
setPredictionCol(value) Sets the value of predictionCol.
setWeightCol(value) Sets the value of weightCol.
write() Returns an MLWriter instance for this ML instance.

Attributes

distanceMeasure
featuresCol
metricName
params Returns all params ordered by name.
predictionCol
weightCol

Methods Documentation

clear(param)
Clears a param from the param map if it has been explicitly set.

copy(extra=None)
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters

extra [dict, optional] Extra parameters to copy to the new instance

Returns

JavaParams Copy of this instance

evaluate(dataset, params=None)
Evaluates the output with optional parameters.

New in version 1.4.0.

Parameters

dataset [pyspark.sql.DataFrame] a dataset that contains labels/observations and predictions

params [dict, optional] an optional param map that overrides embedded params

Returns

float metric

explainParam(param)
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()
Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra=None)
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra [dict, optional] extra param values

Returns


dict merged param map

getDistanceMeasure()
Gets the value of distanceMeasure.

New in version 2.4.0.

getFeaturesCol()
Gets the value of featuresCol or its default value.

getMetricName()
Gets the value of metricName or its default value.

New in version 2.3.0.

getOrDefault(param)
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName)
Gets a param by its name.

getPredictionCol()
Gets the value of predictionCol or its default value.

getWeightCol()
Gets the value of weightCol or its default value.

hasDefault(param)
Checks whether a param has a default value.

hasParam(paramName)
Tests whether this instance contains a param with a given (string) name.

isDefined(param)
Checks whether a param is explicitly set by user or has a default value.

isLargerBetter()
Indicates whether the metric returned by evaluate() should be maximized (True, default) or minimized (False). A given evaluator may support multiple metrics which may be maximized or minimized.

New in version 1.5.0.

isSet(param)
Checks whether a param is explicitly set by user.

classmethod load(path)
Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()
Returns an MLReader instance for this class.

save(path)
Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)
Sets a parameter in the embedded param map.

setDistanceMeasure(value)
Sets the value of distanceMeasure.

New in version 2.4.0.

setFeaturesCol(value)
Sets the value of featuresCol.


setMetricName(value)
Sets the value of metricName.

New in version 2.3.0.

setParams(self, *, predictionCol="prediction", featuresCol="features", metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None)

Sets params for clustering evaluator.

New in version 2.3.0.

setPredictionCol(value)
Sets the value of predictionCol.

setWeightCol(value)
Sets the value of weightCol.

New in version 3.1.0.

write()
Returns an MLWriter instance for this ML instance.

Attributes Documentation

distanceMeasure = Param(parent='undefined', name='distanceMeasure', doc="The distance measure. Supported options: 'squaredEuclidean' and 'cosine'.")

featuresCol = Param(parent='undefined', name='featuresCol', doc='features column name.')

metricName = Param(parent='undefined', name='metricName', doc='metric name in evaluation (silhouette)')

params
Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

predictionCol = Param(parent='undefined', name='predictionCol', doc='prediction column name.')

weightCol = Param(parent='undefined', name='weightCol', doc='weight column name. If this is not set or empty, we treat all instance weights as 1.0.')
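Silhouette can also be computed with the cosine distance instead of the default squared Euclidean distance. A minimal sketch, assuming a dataset with "features" and "prediction" columns as in the example above:

from pyspark.ml.evaluation import ClusteringEvaluator

evaluator = ClusteringEvaluator(metricName="silhouette", distanceMeasure="cosine")
# score = evaluator.evaluate(dataset)  # still ranges between -1 and 1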

RankingEvaluator

class pyspark.ml.evaluation.RankingEvaluator(*args, **kwargs)
Evaluator for Ranking, which expects two input columns: prediction and label.

New in version 3.0.0.

Notes

Experimental
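A minimal usage sketch, assuming an active SparkSession named spark as in the earlier examples; the data values are arbitrary, and both columns must be arrays of doubles:

from pyspark.ml.evaluation import RankingEvaluator

scoreAndLabels = [([1.0, 6.0, 2.0, 7.0], [1.0, 2.0, 3.0]),
                  ([4.0, 1.0, 5.0], [1.0, 4.0, 5.0])]
dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
evaluator = RankingEvaluator(predictionCol="prediction", labelCol="label")
score = evaluator.evaluate(dataset)  # uses the default meanAveragePrecision metric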
