{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## This notebook is part of the Spark training delivered by CERN IT\n", "### Regression with spark.ml\n", "Contact: Luca.Canali@cern.ch\n", "\n", "This notebook is an implementation of a regression system trained using `spark.ml` to predict house prices.\n", "\n", "The data used for this exercise is the \"California Housing Prices dataset\" from the StatLib repository, originally featured in the following paper: Pace, R. Kelley, and Ronald Barry. \"Sparse spatial autoregressions.\" Statistics & Probability Letters 33.3 (1997): 291-297.\n", "The code and steps we follow in this notebook are inspired by the book \"Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow, Aurelien Geron, 2nd Edition\".\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Run this notebook from Jupyter with Python kernel\n", "- When using on CERN SWAN, do not attach the notebook to a Spark cluster, but rather run locally on the SWAN container\n", "- If running this outside CERN SWAN, plese make sure to tha PySpark installed: `pip install pyspark`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create the Spark session and read the data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#\n", "# Local mode: run this when using CERN SWAN not connected to a cluster \n", "# or run it on a private Jupyter notebook instance\n", "# Dependency: PySpark (use SWAN or pip install pyspark)\n", "#\n", "\n", "from pyspark.sql import SparkSession\n", "spark = SparkSession.builder \\\n", " .master(\"local[*]\") \\\n", " .appName(\"ML HandsOn Regression\") \\\n", " .config(\"spark.driver.memory\",\"4g\") \\\n", " .config(\"spark.ui.showConsoleProgress\", \"false\") \\\n", " .getOrCreate()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n",
"<p><b>SparkSession - in-memory</b></p>\n",
"<div>\n",
"<p><b>SparkContext</b></p>\n",
"<p>Spark UI</p>\n",
"<dl>\n",
"<dt>Version</dt><dd><code>v3.3.1</code></dd>\n",
"<dt>Master</dt><dd><code>local[*]</code></dd>\n",
"<dt>AppName</dt><dd><code>ML HandsOn Regression</code></dd>\n",
"</dl>\n",
"</div>\n",
"
\n", " " ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "spark" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Local mode: read the data locally from the cloned repo\n", "df = (spark.read\n", " .format(\"csv\")\n", " .option(\"header\",\"true\")\n", " .option(\"inferschema\",\"true\")\n", " .load(\"../data/housing.csv.gz\")\n", " )\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Split data into a training and test datasets" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16526" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train, test = df.randomSplit([0.8, 0.2], 4242)\n", "\n", "# cache the training dataset\n", "train.cache().count()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic data exploration" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- longitude: double (nullable = true)\n", " |-- latitude: double (nullable = true)\n", " |-- housing_median_age: double (nullable = true)\n", " |-- total_rooms: double (nullable = true)\n", " |-- total_bedrooms: double (nullable = true)\n", " |-- population: double (nullable = true)\n", " |-- households: double (nullable = true)\n", " |-- median_income: double (nullable = true)\n", " |-- median_house_value: double (nullable = true)\n", " |-- ocean_proximity: string (nullable = true)\n", "\n" ] } ], "source": [ "train.printSchema()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"<thead><tr><th></th><th>longitude</th><th>latitude</th><th>housing_median_age</th><th>total_rooms</th><th>total_bedrooms</th><th>population</th><th>households</th><th>median_income</th><th>median_house_value</th><th>ocean_proximity</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>-124.35</td><td>40.54</td><td>52.0</td><td>1820.0</td><td>300.0</td><td>806.0</td><td>270.0</td><td>3.0147</td><td>94600.0</td><td>NEAR OCEAN</td></tr>\n",
"<tr><th>1</th><td>-124.30</td><td>41.80</td><td>19.0</td><td>2672.0</td><td>552.0</td><td>1298.0</td><td>478.0</td><td>1.9797</td><td>85800.0</td><td>NEAR OCEAN</td></tr>\n",
"<tr><th>2</th><td>-124.30</td><td>41.84</td><td>17.0</td><td>2677.0</td><td>531.0</td><td>1244.0</td><td>456.0</td><td>3.0313</td><td>103600.0</td><td>NEAR OCEAN</td></tr>\n",
"<tr><th>3</th><td>-124.27</td><td>40.69</td><td>36.0</td><td>2349.0</td><td>528.0</td><td>1194.0</td><td>465.0</td><td>2.5179</td><td>79000.0</td><td>NEAR OCEAN</td></tr>\n",
"<tr><th>4</th><td>-124.26</td><td>40.58</td><td>52.0</td><td>2217.0</td><td>394.0</td><td>907.0</td><td>369.0</td><td>2.3571</td><td>111400.0</td><td>NEAR OCEAN</td></tr>\n",
"<tr><th>5</th><td>-124.25</td><td>40.28</td><td>32.0</td><td>1430.0</td><td>419.0</td><td>434.0</td><td>187.0</td><td>1.9417</td><td>76100.0</td><td>NEAR OCEAN</td></tr>\n",
"<tr><th>6</th><td>-124.23</td><td>40.54</td><td>52.0</td><td>2694.0</td><td>453.0</td><td>1152.0</td><td>435.0</td><td>3.0806</td><td>106700.0</td><td>NEAR OCEAN</td></tr>\n",
"<tr><th>7</th><td>-124.23</td><td>41.75</td><td>11.0</td><td>3159.0</td><td>616.0</td><td>1343.0</td><td>479.0</td><td>2.4805</td><td>73200.0</td><td>NEAR OCEAN</td></tr>\n",
"<tr><th>8</th><td>-124.22</td><td>41.73</td><td>28.0</td><td>3003.0</td><td>699.0</td><td>1530.0</td><td>653.0</td><td>1.7038</td><td>78300.0</td><td>NEAR OCEAN</td></tr>\n",
"<tr><th>9</th><td>-124.21</td><td>40.75</td><td>32.0</td><td>1218.0</td><td>331.0</td><td>620.0</td><td>268.0</td><td>1.6528</td><td>58100.0</td><td>NEAR OCEAN</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " longitude latitude housing_median_age total_rooms total_bedrooms \\\n", "0 -124.35 40.54 52.0 1820.0 300.0 \n", "1 -124.30 41.80 19.0 2672.0 552.0 \n", "2 -124.30 41.84 17.0 2677.0 531.0 \n", "3 -124.27 40.69 36.0 2349.0 528.0 \n", "4 -124.26 40.58 52.0 2217.0 394.0 \n", "5 -124.25 40.28 32.0 1430.0 419.0 \n", "6 -124.23 40.54 52.0 2694.0 453.0 \n", "7 -124.23 41.75 11.0 3159.0 616.0 \n", "8 -124.22 41.73 28.0 3003.0 699.0 \n", "9 -124.21 40.75 32.0 1218.0 331.0 \n", "\n", " population households median_income median_house_value ocean_proximity \n", "0 806.0 270.0 3.0147 94600.0 NEAR OCEAN \n", "1 1298.0 478.0 1.9797 85800.0 NEAR OCEAN \n", "2 1244.0 456.0 3.0313 103600.0 NEAR OCEAN \n", "3 1194.0 465.0 2.5179 79000.0 NEAR OCEAN \n", "4 907.0 369.0 2.3571 111400.0 NEAR OCEAN \n", "5 434.0 187.0 1.9417 76100.0 NEAR OCEAN \n", "6 1152.0 435.0 3.0806 106700.0 NEAR OCEAN \n", "7 1343.0 479.0 2.4805 73200.0 NEAR OCEAN \n", "8 1530.0 653.0 1.7038 78300.0 NEAR OCEAN \n", "9 620.0 268.0 1.6528 58100.0 NEAR OCEAN " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The dataset reports housing prices in California from 1990s\n", "train.limit(10).toPandas()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---------------+--------+\n", "|ocean_proximity|count(1)|\n", "+---------------+--------+\n", "| ISLAND| 3|\n", "| NEAR OCEAN| 2133|\n", "| NEAR BAY| 1860|\n", "| <1H OCEAN| 7298|\n", "| INLAND| 5232|\n", "+---------------+--------+\n", "\n" ] } ], "source": [ "train.createOrReplaceTempView(\"train\")\n", "spark.sql(\"select ocean_proximity, count(*) from train group by ocean_proximity\").show()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------+\n", "|count(1)|\n", "+--------+\n", "| 175|\n", "+--------+\n", "\n" ] } ], 
"source": [ "# the are some missing data in the total_bedrooms feature (i.e. there are null values)\n", "\n", "spark.sql(\"select count(*) from train where total_bedrooms is null\").show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature preparation" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.feature import StringIndexer,OneHotEncoder,VectorIndexer,Imputer,VectorAssembler, StandardScaler\n", "from pyspark.ml import Pipeline" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Transform ocean_proximity feature in a one-hot encoded feature \n", "ocean_index = StringIndexer(inputCol=\"ocean_proximity\",outputCol=\"indexed_ocean_proximity\")\n", "ocean_onehot = OneHotEncoder(inputCol=\"indexed_ocean_proximity\",outputCol=\"oh_ocean_proximity\",dropLast=False)\n", "\n", "# Add missing data to the total_bedrooms feature, by using estimation.\n", "imputer_tot_br = Imputer(strategy='median',inputCols=[\"total_bedrooms\"],outputCols=[\"total_bedrooms_filled\"])\n", "\n", "features = [\"longitude\", \"latitude\", \"housing_median_age\", \n", " \"total_rooms\", \"population\", \"households\", \n", " \"median_income\", \"total_bedrooms_filled\"]\n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build a pipeline, bundling the feature preparation steps" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "feature_preparation_pipeline = Pipeline(stages=[ocean_index,ocean_onehot,imputer_tot_br])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"<thead><tr><th></th><th>longitude</th><th>latitude</th><th>housing_median_age</th><th>total_rooms</th><th>total_bedrooms</th><th>population</th><th>households</th><th>median_income</th><th>median_house_value</th><th>ocean_proximity</th><th>indexed_ocean_proximity</th><th>oh_ocean_proximity</th><th>total_bedrooms_filled</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>-124.35</td><td>40.54</td><td>52.0</td><td>1820.0</td><td>300.0</td><td>806.0</td><td>270.0</td><td>3.0147</td><td>94600.0</td><td>NEAR OCEAN</td><td>2.0</td><td>(0.0, 0.0, 1.0, 0.0, 0.0)</td><td>300.0</td></tr>\n",
"<tr><th>1</th><td>-124.30</td><td>41.80</td><td>19.0</td><td>2672.0</td><td>552.0</td><td>1298.0</td><td>478.0</td><td>1.9797</td><td>85800.0</td><td>NEAR OCEAN</td><td>2.0</td><td>(0.0, 0.0, 1.0, 0.0, 0.0)</td><td>552.0</td></tr>\n",
"<tr><th>2</th><td>-124.30</td><td>41.84</td><td>17.0</td><td>2677.0</td><td>531.0</td><td>1244.0</td><td>456.0</td><td>3.0313</td><td>103600.0</td><td>NEAR OCEAN</td><td>2.0</td><td>(0.0, 0.0, 1.0, 0.0, 0.0)</td><td>531.0</td></tr>\n",
"<tr><th>3</th><td>-124.27</td><td>40.69</td><td>36.0</td><td>2349.0</td><td>528.0</td><td>1194.0</td><td>465.0</td><td>2.5179</td><td>79000.0</td><td>NEAR OCEAN</td><td>2.0</td><td>(0.0, 0.0, 1.0, 0.0, 0.0)</td><td>528.0</td></tr>\n",
"<tr><th>4</th><td>-124.26</td><td>40.58</td><td>52.0</td><td>2217.0</td><td>394.0</td><td>907.0</td><td>369.0</td><td>2.3571</td><td>111400.0</td><td>NEAR OCEAN</td><td>2.0</td><td>(0.0, 0.0, 1.0, 0.0, 0.0)</td><td>394.0</td></tr>\n",
"<tr><th>5</th><td>-124.25</td><td>40.28</td><td>32.0</td><td>1430.0</td><td>419.0</td><td>434.0</td><td>187.0</td><td>1.9417</td><td>76100.0</td><td>NEAR OCEAN</td><td>2.0</td><td>(0.0, 0.0, 1.0, 0.0, 0.0)</td><td>419.0</td></tr>\n",
"<tr><th>6</th><td>-124.23</td><td>40.54</td><td>52.0</td><td>2694.0</td><td>453.0</td><td>1152.0</td><td>435.0</td><td>3.0806</td><td>106700.0</td><td>NEAR OCEAN</td><td>2.0</td><td>(0.0, 0.0, 1.0, 0.0, 0.0)</td><td>453.0</td></tr>\n",
"<tr><th>7</th><td>-124.23</td><td>41.75</td><td>11.0</td><td>3159.0</td><td>616.0</td><td>1343.0</td><td>479.0</td><td>2.4805</td><td>73200.0</td><td>NEAR OCEAN</td><td>2.0</td><td>(0.0, 0.0, 1.0, 0.0, 0.0)</td><td>616.0</td></tr>\n",
"<tr><th>8</th><td>-124.22</td><td>41.73</td><td>28.0</td><td>3003.0</td><td>699.0</td><td>1530.0</td><td>653.0</td><td>1.7038</td><td>78300.0</td><td>NEAR OCEAN</td><td>2.0</td><td>(0.0, 0.0, 1.0, 0.0, 0.0)</td><td>699.0</td></tr>\n",
"<tr><th>9</th><td>-124.21</td><td>40.75</td><td>32.0</td><td>1218.0</td><td>331.0</td><td>620.0</td><td>268.0</td><td>1.6528</td><td>58100.0</td><td>NEAR OCEAN</td><td>2.0</td><td>(0.0, 0.0, 1.0, 0.0, 0.0)</td><td>331.0</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " longitude latitude housing_median_age total_rooms total_bedrooms \\\n", "0 -124.35 40.54 52.0 1820.0 300.0 \n", "1 -124.30 41.80 19.0 2672.0 552.0 \n", "2 -124.30 41.84 17.0 2677.0 531.0 \n", "3 -124.27 40.69 36.0 2349.0 528.0 \n", "4 -124.26 40.58 52.0 2217.0 394.0 \n", "5 -124.25 40.28 32.0 1430.0 419.0 \n", "6 -124.23 40.54 52.0 2694.0 453.0 \n", "7 -124.23 41.75 11.0 3159.0 616.0 \n", "8 -124.22 41.73 28.0 3003.0 699.0 \n", "9 -124.21 40.75 32.0 1218.0 331.0 \n", "\n", " population households median_income median_house_value ocean_proximity \\\n", "0 806.0 270.0 3.0147 94600.0 NEAR OCEAN \n", "1 1298.0 478.0 1.9797 85800.0 NEAR OCEAN \n", "2 1244.0 456.0 3.0313 103600.0 NEAR OCEAN \n", "3 1194.0 465.0 2.5179 79000.0 NEAR OCEAN \n", "4 907.0 369.0 2.3571 111400.0 NEAR OCEAN \n", "5 434.0 187.0 1.9417 76100.0 NEAR OCEAN \n", "6 1152.0 435.0 3.0806 106700.0 NEAR OCEAN \n", "7 1343.0 479.0 2.4805 73200.0 NEAR OCEAN \n", "8 1530.0 653.0 1.7038 78300.0 NEAR OCEAN \n", "9 620.0 268.0 1.6528 58100.0 NEAR OCEAN \n", "\n", " indexed_ocean_proximity oh_ocean_proximity total_bedrooms_filled \n", "0 2.0 (0.0, 0.0, 1.0, 0.0, 0.0) 300.0 \n", "1 2.0 (0.0, 0.0, 1.0, 0.0, 0.0) 552.0 \n", "2 2.0 (0.0, 0.0, 1.0, 0.0, 0.0) 531.0 \n", "3 2.0 (0.0, 0.0, 1.0, 0.0, 0.0) 528.0 \n", "4 2.0 (0.0, 0.0, 1.0, 0.0, 0.0) 394.0 \n", "5 2.0 (0.0, 0.0, 1.0, 0.0, 0.0) 419.0 \n", "6 2.0 (0.0, 0.0, 1.0, 0.0, 0.0) 453.0 \n", "7 2.0 (0.0, 0.0, 1.0, 0.0, 0.0) 616.0 \n", "8 2.0 (0.0, 0.0, 1.0, 0.0, 0.0) 699.0 \n", "9 2.0 (0.0, 0.0, 1.0, 0.0, 0.0) 331.0 " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# fit the feature preparation pipeline with trinaing data and show the \n", "feature_preparation_transformer = feature_preparation_pipeline.fit(train)\n", "\n", "# show a sample of data after feature preparation\n", "feature_preparation_transformer.transform(train).limit(10).toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "### Further data preparation\n", "\n", "Vector assembler puts all data in a vector column. This step is required by the Spark ML algorithms. \n", "Standard scaler is a data preparation step. StandardScaler follows Standard Normal Distribution (SND). Therefore, it makes mean = 0 and scales the data to unit variance. " ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "assembler = VectorAssembler(inputCols=features, outputCol=\"unscaled_features\")\n", "\n", "std_scaler = StandardScaler(inputCol=\"unscaled_features\", outputCol=\"features\")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "full_feature_preparation_pipeline = Pipeline(stages=[feature_preparation_pipeline,assembler,std_scaler])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"<thead><tr><th></th><th>unscaled_features</th><th>features</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>[-124.35, 40.54, 52.0, 1820.0, 806.0, 270.0, 3...</td><td>[-62.03326688689804, 18.957877035296093, 4.137...</td></tr>\n",
"<tr><th>1</th><td>[-124.3, 41.8, 19.0, 2672.0, 1298.0, 478.0, 1....</td><td>[-62.00832387648916, 19.547095709802086, 1.511...</td></tr>\n",
"<tr><th>2</th><td>[-124.3, 41.84, 17.0, 2677.0, 1244.0, 456.0, 3...</td><td>[-62.00832387648916, 19.56580106454831, 1.3527...</td></tr>\n",
"<tr><th>3</th><td>[-124.27, 40.69, 36.0, 2349.0, 1194.0, 465.0, ...</td><td>[-61.99335807024382, 19.028022115594425, 2.864...</td></tr>\n",
"<tr><th>4</th><td>[-124.26, 40.58, 52.0, 2217.0, 907.0, 369.0, 2...</td><td>[-61.988369468162055, 18.976582390042314, 4.13...</td></tr>\n",
"<tr><th>5</th><td>[-124.25, 40.28, 32.0, 1430.0, 434.0, 187.0, 1...</td><td>[-61.983380866080275, 18.83629222944565, 2.546...</td></tr>\n",
"<tr><th>6</th><td>[-124.23, 40.54, 52.0, 2694.0, 1152.0, 435.0, ...</td><td>[-61.97340366191672, 18.957877035296093, 4.137...</td></tr>\n",
"<tr><th>7</th><td>[-124.23, 41.75, 11.0, 3159.0, 1343.0, 479.0, ...</td><td>[-61.97340366191672, 19.52371401636931, 0.8753...</td></tr>\n",
"<tr><th>8</th><td>[-124.22, 41.73, 28.0, 3003.0, 1530.0, 653.0, ...</td><td>[-61.96841505983494, 19.5143613389962, 2.22808...</td></tr>\n",
"<tr><th>9</th><td>[-124.21, 40.75, 32.0, 1218.0, 620.0, 268.0, 1...</td><td>[-61.96342645775316, 19.056080147713757, 2.546...</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " unscaled_features \\\n", "0 [-124.35, 40.54, 52.0, 1820.0, 806.0, 270.0, 3... \n", "1 [-124.3, 41.8, 19.0, 2672.0, 1298.0, 478.0, 1.... \n", "2 [-124.3, 41.84, 17.0, 2677.0, 1244.0, 456.0, 3... \n", "3 [-124.27, 40.69, 36.0, 2349.0, 1194.0, 465.0, ... \n", "4 [-124.26, 40.58, 52.0, 2217.0, 907.0, 369.0, 2... \n", "5 [-124.25, 40.28, 32.0, 1430.0, 434.0, 187.0, 1... \n", "6 [-124.23, 40.54, 52.0, 2694.0, 1152.0, 435.0, ... \n", "7 [-124.23, 41.75, 11.0, 3159.0, 1343.0, 479.0, ... \n", "8 [-124.22, 41.73, 28.0, 3003.0, 1530.0, 653.0, ... \n", "9 [-124.21, 40.75, 32.0, 1218.0, 620.0, 268.0, 1... \n", "\n", " features \n", "0 [-62.03326688689804, 18.957877035296093, 4.137... \n", "1 [-62.00832387648916, 19.547095709802086, 1.511... \n", "2 [-62.00832387648916, 19.56580106454831, 1.3527... \n", "3 [-61.99335807024382, 19.028022115594425, 2.864... \n", "4 [-61.988369468162055, 18.976582390042314, 4.13... \n", "5 [-61.983380866080275, 18.83629222944565, 2.546... \n", "6 [-61.97340366191672, 18.957877035296093, 4.137... \n", "7 [-61.97340366191672, 19.52371401636931, 0.8753... \n", "8 [-61.96841505983494, 19.5143613389962, 2.22808... \n", "9 [-61.96342645775316, 19.056080147713757, 2.546... 
" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# this shows the results of data scaling\n", "full_feature_preparation_transformer = full_feature_preparation_pipeline.fit(train)\n", "\n", "full_feature_preparation_transformer.transform(train).select(\"unscaled_features\",\"features\").limit(10).toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define the model and assemble a pipeline" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.regression import GBTRegressor\n", "\n", "regressor = GBTRegressor(labelCol=\"median_house_value\", maxIter=40)\n", "\n", "pipeline = Pipeline(stages=[full_feature_preparation_pipeline, regressor])\n", "# this is equivalent to\n", "# pipeline = Pipeline(stages=[ocean_index, ocean_onehot, imputer_tot_br, assembler, std_scaler, regressor])\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fit the model using the training dataset" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# model training \n", "# this uses the pipeline built above\n", "# the pipeline puts together transformers and the model and is an estimator\n", "# we are going to fit it to the training data\n", "\n", "model = pipeline.fit(train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# the trained model can be saved on the filesystem\n", "\n", "model.save(\"myTrainedModel\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate the model performance on the test dataset by computing RMSE" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Root Mean Squared Error (RMSE) on test data = 54528.9\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-09-29 21:02:23,307 WARN netlib.InstanceBuilder$NativeBLAS: Failed 
to load implementation from:dev.ludovic.netlib.blas.JNIBLAS\n", "2022-09-29 21:02:23,320 WARN netlib.InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.ForeignLinkerBLAS\n" ] } ], "source": [ "from pyspark.ml.evaluation import RegressionEvaluator\n", "\n", "predictions = model.transform(test)\n", "dt_evaluator = RegressionEvaluator(\n", " labelCol=\"median_house_value\", predictionCol=\"prediction\", metricName=\"rmse\")\n", "rmse = dt_evaluator.evaluate(predictions)\n", "\n", "print(\"Root Mean Squared Error (RMSE) on test data = %g\" % rmse)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Correlation Matrix\n", "The correlation matrix shows the pairwise (Pearson) linear correlation between the features. \n", "Correlation ranges from -1 to +1. Values close to zero mean there is little or no linear relationship between the two variables. \n", "It is often displayed using a heatmap." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/cvmfs/sft.cern.ch/lcg/views/LCG_102swan/x86_64-centos7-gcc11-opt/python/pyspark/sql/context.py:125: FutureWarning: Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.\n", " warnings.warn(\n" ] } ], "source": [ "from pyspark.ml.stat import Correlation\n", "\n", "matrix = Correlation.corr(full_feature_preparation_transformer.transform(train).select('features'), 'features')\n", "matrix_np = matrix.collect()[0][\"pearson({})\".format('features')].values" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import seaborn as sns\n", "import matplotlib.pyplot as plt \n", "\n", "matrix_np = matrix_np.reshape(8,8)\n", "\n", "fig, ax = plt.subplots(figsize=(12,8))\n", "ax = sns.heatmap(matrix_np, cmap=\"Blues\")\n", "ax.xaxis.set_ticklabels(features, rotation=270)\n", "ax.yaxis.set_ticklabels(features, rotation=0)\n", "ax.set_title(\"Correlation Matrix\")\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## An example of cross validation and grid search " ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "## This crossvalidation step takes several minutes, depending on the available cores\n", "\n", "from pyspark.ml.tuning import ParamGridBuilder, CrossValidator\n", "from pyspark.ml.evaluation import RegressionEvaluator\n", "\n", "\n", "paramGrid = ParamGridBuilder()\\\n", " .addGrid(regressor.maxIter, [100,50]) \\\n", " .baseOn({regressor.labelCol: \"median_house_value\"})\\\n", " .build()\n", "\n", "\n", "crossval = CrossValidator(estimator=pipeline,\n", " estimatorParamMaps=paramGrid,\n", " evaluator=RegressionEvaluator(labelCol=\"median_house_value\"),\n", " numFolds=4)\n", "cvModel=crossval.fit(train)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Root Mean Squared Error (RMSE) on test data = 51711.3\n" ] } ], "source": [ "from pyspark.ml.evaluation import RegressionEvaluator\n", "\n", "predictions = cvModel.transform(test)\n", "dt_evaluator = RegressionEvaluator(\n", " labelCol=\"median_house_value\", predictionCol=\"prediction\", metricName=\"rmse\")\n", "rmse = dt_evaluator.evaluate(predictions)\n", "\n", "print(\"Root Mean Squared Error (RMSE) on test data = %g\" % rmse)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "spark.stop()" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "sparkconnect": { "bundled_options": [], "list_of_options": [] } }, "nbformat": 4, "nbformat_minor": 2 }