Predict Low Birth Weight Cases for Newborn Babies

Logistic regression will be used to predict low birth weight in infants from the remaining variables in the dataset. The target variable, LOW, equals 1 when birth weight (BWT) is at most 2500 g and 0 when BWT exceeds 2500 g.

Low BWT Data

Dataset Source: Hosmer, D.W., Lemeshow, S. and Sturdivant, R.X. (2013) Applied Logistic Regression, 3rd ed., New York: Wiley

This dataset is also part of the aplore3 R package.

Variable Description

ID     Baby's unique identifier
LOW    (target) Low birth weight indicator (0: BWT >= 2500 g, 1: BWT < 2500 g)
AGE    Mother's age in years
RACE   Race (1: White, 2: Black, 3: Other)
SMOKE  Smoking status during pregnancy (0: No, 1: Yes)
PTL    Number of previous premature labors (0: None, 1: One, 2: Two, etc.)
HT     History of hypertension (0: No, 1: Yes)
UI     Presence of uterine irritability (0: No, 1: Yes)
FTV    Number of physician visits during the first trimester (0: None, 1: One, 2: Two, etc.)

Load Libraries

The Spark and Python libraries that you need are preinstalled in the notebook environment and only need to be loaded.

Run the following cell to load the libraries you will work with in this notebook:

In [1]:
# PySpark Machine Learning Library
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.sql import Row, SQLContext

from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# Library for confusion matrix, precision, recall and test error
from pyspark.mllib.evaluation import MulticlassMetrics
# Library for area under the ROC curve and area under the precision-recall curve
from pyspark.mllib.evaluation import BinaryClassificationMetrics

# Assign resources to the application (sc is the preconfigured SparkContext)
sqlContext = SQLContext(sc)

# Packages for data analysis
import numpy as np
import pandas as pd
Waiting for a Spark session to start...
Spark Initialization Done! ApplicationId = app-20210226052557-0000
KERNEL_ID = b5fc5388-64e5-41a0-a188-c156cf69c119
In [2]:
# Summary of the data structure, including column position and name.
# The first field is at position 0.

# 0 ID      -  Identification code
# 1 LOW     -  Low birth weight (0: >= 2500 g, 1: < 2500 g), target variable
# 2 AGE     -  Mother's age in years
# 3 RACE    -  Race (1: White, 2: Black, 3: Other)
# 4 SMOKE   -  Smoking status during pregnancy (0: No, 1: Yes)
# 5 PTL     -  Number of previous premature labors (0, 1, 2, ...)
# 6 HT      -  History of hypertension (0: No, 1: Yes)
# 7 UI      -  Presence of uterine irritability (0: No, 1: Yes)
# 8 FTV     -  Number of physician visits during the first trimester (0, 1, 2, ...)

# "label" is the target variable; "PersonInfo" collects the independent
# variables (everything except the unique identifier) into a single string.

LabeledDocument = Row("ID", "PersonInfo", "label")

# Define a function that parses a raw CSV line and returns a LabeledDocument

def parseDocument(line):
    values = line.split(',')
    # label is 1.0 for low birth weight (LOW > 0), 0.0 otherwise
    if values[1] > '0':
        LOW = 1.0
    else:
        LOW = 0.0

    # Join the predictors into one space-separated string for the Tokenizer.
    # SMOKE+PTL and HT+UI are concatenated without a space, so each pair forms
    # a single combined token (e.g. '10' = smoker with no previous premature labor).
    textValue = values[2] + " " + values[3] + " " + values[4] + values[5] \
        + " " + values[6] + values[7] + " " + values[8]
    return LabeledDocument(values[0], textValue, LOW)
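
As a quick sanity check (a sketch, not part of the original flow), you can call parseDocument on a single raw line to see the encoding it produces:

In [ ]:
# Illustrative check: parse one raw CSV line (values taken from the data shown later)
sample = "85,0,19,2,0,0,0,1,0"
print(parseDocument(sample))
# Expected: Row(ID='85', PersonInfo='19 2 00 01 0', label=0.0)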

Access Data

To read a file from Object Storage, you must set up the Spark configuration with your Object Storage credentials.

To do this, click the cell below and, on the Files tab, select Insert to code > Insert SparkSession DataFrame under the data file you want to work with.

The following code contains the credentials for a file in your IBM Cloud Object Storage.
In [3]:
import ibmos2spark
# @hidden_cell
credentials = {
    'endpoint': 'https://s3-api.us-geo.objectstorage.service.networklayer.com',
    'service_id': 'iam-ServiceId-7c0fb81a-082a-4b1b-9f2e-eb2eb9bfbb3d',
    'iam_service_endpoint': 'https://iam.cloud.ibm.com/oidc/token',
    'api_key': '********'  # replace with your own API key
}

configuration_name = 'os_697aafb312964e4db161e65d0a2ee97e_configs'
cos = ibmos2spark.CloudObjectStorage(sc, credentials, configuration_name, 'bluemix_cos')

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
df_data_1 = spark.read\
  .format('csv')\
  .option('header', 'true')\
  .option('inferSchema', 'true')\
  .load(cos.url('lowbwt.csv', 'apachesparktutorial-donotdelete-pr-gzhhj3lfvqnwyl'))
df_data_1.take(5)
Out[3]:
[Row(ID=85, LOW=0, AGE=19, RACE=2, SMOKE=0, PTL=0, HT=0, UI=1, FTV=0),
 Row(ID=86, LOW=0, AGE=33, RACE=3, SMOKE=0, PTL=0, HT=0, UI=0, FTV=3),
 Row(ID=87, LOW=0, AGE=20, RACE=1, SMOKE=1, PTL=0, HT=0, UI=0, FTV=1),
 Row(ID=88, LOW=0, AGE=21, RACE=1, SMOKE=1, PTL=0, HT=0, UI=1, FTV=2),
 Row(ID=89, LOW=0, AGE=18, RACE=1, SMOKE=1, PTL=0, HT=0, UI=1, FTV=0)]
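
Because inferSchema is enabled, the numeric columns are read as integers rather than strings. A quick check (a sketch; output not shown):

In [ ]:
# Confirm the inferred column types
df_data_1.printSchema()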

Explore Data

In [4]:
print('Number of cases', df_data_1.count())
Number of cases 189
In [10]:
# Number of babies born weighing less than 2.5 kg vs. 2.5 kg or more
df_data_1.groupby('LOW').count().show()
+---+-----+
|LOW|count|
+---+-----+
|  1|   59|
|  0|  130|
+---+-----+
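
About 31% of the newborns in this sample have low birth weight. A small follow-up sketch to compute the rate directly:

In [ ]:
# Proportion of low-birth-weight cases (59 of 189, about 31.2%)
n_total = df_data_1.count()
n_low = df_data_1.filter(df_data_1['LOW'] == 1).count()
print('Low birth weight rate: %.1f%%' % (100.0 * n_low / n_total))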

In [38]:
#Low birth weight grouped by Race

import matplotlib.pyplot as plt
%matplotlib inline
df_data_1.crosstab('RACE', 'LOW').show()
df=df_data_1.crosstab('RACE', 'LOW').toPandas()
df.plot.bar(x="RACE_LOW", legend=True , title="Low Birth Weight by Mother's Race")
+--------+---+---+
|RACE_LOW|  0|  1|
+--------+---+---+
|       2| 15| 11|
|       1| 73| 23|
|       3| 42| 25|
+--------+---+---+

Out[38]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f1ef9027c50>
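
Raw counts can mislead when group sizes differ, so it may help to compare within-group rates as well. A sketch using the crosstab above (the columns '0' and '1' are the string labels pandas gives the crosstab counts):

In [ ]:
# Low-birth-weight rate within each race group
pdf_race = df_data_1.crosstab('RACE', 'LOW').toPandas()
pdf_race['rate_low'] = pdf_race['1'] / (pdf_race['0'] + pdf_race['1'])
print(pdf_race.sort_values('RACE_LOW'))
# From the counts above: White 23/96 ≈ 0.24, Black 11/26 ≈ 0.42, Other 25/67 ≈ 0.37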
In [40]:
# Low birth weight grouped by number of first-trimester physician visits

import matplotlib.pyplot as plt
%matplotlib inline
df_data_1.crosstab('FTV', 'LOW').show()
df=df_data_1.crosstab('FTV', 'LOW').toPandas()
df.plot.bar(x="FTV_LOW", legend=True , title="Low Birth Weight by Mother's regular visit to doctor")
+-------+---+---+
|FTV_LOW|  0|  1|
+-------+---+---+
|      0| 64| 36|
|      1| 36| 11|
|      6|  1|  0|
|      2| 23|  7|
|      3|  3|  4|
|      4|  3|  1|
+-------+---+---+

Out[40]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f1ef9123b50>
In [37]:
#Low birth weight grouped by smoking status
import matplotlib.pyplot as plt
%matplotlib inline
df_data_1.crosstab('SMOKE', 'LOW').show()
df=df_data_1.crosstab('SMOKE', 'LOW').toPandas()
df.plot.bar(x="SMOKE_LOW", legend=True , title="Low Birth Weight by Smoking Status")
+---------+---+---+
|SMOKE_LOW|  0|  1|
+---------+---+---+
|        1| 44| 30|
|        0| 86| 29|
+---------+---+---+

Out[37]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f1ef9040410>
In [17]:
# Average mother's age by LOW outcome
# Ensure AGE is numeric (a no-op here, since inferSchema already reads it as an integer)
pdf=df_data_1.toPandas()
pdf["AGE"]=pd.to_numeric(pdf.AGE)

df=sqlContext.createDataFrame(pdf)

PAge=df.groupby(['LOW'])\
.agg({"AGE": "AVG"}).toPandas()
PAge
Out[17]:
   LOW   avg(AGE)
0    0  23.661538
1    1  22.305085
In [18]:
# Age distribution for all mothers
df.toPandas()["AGE"].plot.hist(title="Age distribution")
Out[18]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f1efb421950>
In [19]:
# Age distribution for mothers of children with low birth weight
df_low = df.select('AGE', 'LOW').filter(df['LOW'] == '1')
pdf_low = df_low.select('AGE').toPandas()
pdf_low["AGE"].plot.hist(color='m', legend=False, title="Age distribution for mothers of low-birth-weight children")
Out[19]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f1ef9a60b90>

Parse Data

Now let's load the data into a Spark RDD and output the number of rows and the first 5 rows. Each project you create has a bucket in your object storage; you can find the bucket name on the project Settings page. If you run this notebook yourself, replace the bucket name in the cell below with your own.

In [20]:
data = sc.textFile(cos.url('lowbwt.csv', 'apachesparktutorial-donotdelete-pr-gzhhj3lfvqnwyl'))
print ("Total records in the data set:", data.count())
print ("The first 5 rows")
data.take(5)
Total records in the data set: 190
The first 5 rows
Out[20]:
['ID,LOW,AGE,RACE,SMOKE,PTL,HT,UI,FTV',
 '85,0,19,2,0,0,0,1,0',
 '86,0,33,3,0,0,0,0,3',
 '87,0,20,1,1,0,0,0,1',
 '88,0,21,1,1,0,0,1,2']

Create DataFrame from RDD

In [21]:
# Parse and load the data into a DataFrame using the parsing function defined above.
# Note: the filter condition ("Name" not in s) never matches this file, so the
# header line is parsed as a record too (it appears as the first row below).
documents = data.filter(lambda s: "Name" not in s).map(parseDocument)
lowbwtData = documents.toDF() # ToDataFrame
print ("Number of records: " + str(lowbwtData.count()))
print ( "First 5 records: ")
lowbwtData.take(5)
Number of records: 190
First 5 records: 
Out[21]:
[Row(ID='ID', PersonInfo='AGE RACE SMOKEPTL HTUI FTV', label=1.0),
 Row(ID='85', PersonInfo='19 2 00 01 0', label=0.0),
 Row(ID='86', PersonInfo='33 3 00 00 3', label=0.0),
 Row(ID='87', PersonInfo='20 1 10 00 1', label=0.0),
 Row(ID='88', PersonInfo='21 1 10 01 2', label=0.0)]
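
The first record above is the parsed header line (label 1.0), because the filter string "Name" never occurs in this file. A stricter filter would drop it (a hypothetical sketch only; the cells below keep the 190-record version so their printed outputs remain reproducible):

In [ ]:
# Hypothetical stricter filter: drop the header line explicitly
documents_clean = data.filter(lambda s: not s.startswith('ID,')).map(parseDocument)
print('Records without header:', documents_clean.count())  # 189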

Split Data into Training and Test Sets

We divide the data into a training set and a test set. The training set is used to build the model to be applied to future data, and the test set is used to evaluate the model.

In [25]:
# Divide the data into training and test set, with random seed to reproduce results
(train, test) = lowbwtData.randomSplit([0.6, 0.4], seed = 123)
print ("Number of records in the training set: " + str(train.count()))
print ("Number of records in the test set: " + str(test.count()))
# Output first 20 records in the training set
print ("First 20 records in the training set: ")
train.show()
Number of records in the training set: 109
Number of records in the test set: 81
First 20 records in the training set: 
+---+------------+-----+
| ID|  PersonInfo|label|
+---+------------+-----+
|100|18 1 10 00 0|  0.0|
|103|25 1 10 00 3|  0.0|
|104|20 3 00 01 0|  0.0|
|105|28 1 10 00 1|  0.0|
|107|31 1 00 01 3|  0.0|
|108|36 1 00 00 1|  0.0|
|111|25 3 00 01 2|  0.0|
|112|28 1 00 00 0|  0.0|
|115|26 2 10 00 0|  0.0|
|116|17 2 00 00 1|  0.0|
|117|17 2 00 00 1|  0.0|
|118|24 1 11 00 1|  0.0|
|120|25 1 00 00 1|  0.0|
|121|25 2 00 00 0|  0.0|
|123|29 1 10 00 2|  0.0|
|126|31 1 10 00 2|  0.0|
|127|33 1 10 00 1|  0.0|
|128|21 2 10 00 2|  0.0|
|130|23 2 00 00 1|  0.0|
|133|18 1 10 01 0|  0.0|
+---+------------+-----+
only showing top 20 rows
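
Note that randomSplit does not stratify by label, so it is worth confirming that both splits contain a reasonable share of positive cases (a sketch; output not shown):

In [ ]:
# Check the label balance in each split
train.groupBy('label').count().show()
test.groupBy('label').count().show()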

Build Logistic Regression Model

We use a Spark ML Pipeline to build the logistic regression model. The pipeline has three stages, executed in order: a Tokenizer that splits PersonInfo into words, a HashingTF that hashes the words into term-frequency feature vectors, and the LogisticRegression estimator itself.

In [26]:
# set up Logistic Regression using Pipeline of SparkML
tokenizer = Tokenizer(inputCol="PersonInfo", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=50, regParam=0.001)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
In [27]:
# set up Logistic Regression Model
# the stages are executed in order
model = pipeline.fit(train)
#[stage.coefficients for stage in model.stages if hasattr(stage, "coefficients")]
# model.stages[2].intercept
model.stages[2].coefficients
Out[27]:
SparseVector(262144, {4914: 0.5012, 12524: 0.6831, 15544: -3.2286, 25008: 1.4439, 44189: -3.3237, 44445: 0.0885, 46045: -0.4753, 52525: -1.2979, 73702: 1.4439, 76985: 0.1023, 87466: -0.2963, 87985: 0.2905, 90804: 0.2595, 92651: -0.608, 94522: -0.0995, 103775: 4.0311, 110130: 0.5996, 124674: -2.4578, 136198: 6.9336, 137399: -0.0243, 145674: 0.7275, 147703: 1.4439, 155537: 0.7721, 166737: -0.6849, 168211: -2.6592, 168590: 0.7616, 177679: 0.6976, 190344: 6.2575, 190727: -5.2239, 191897: 1.4439, 194395: -2.9711, 194802: -4.7155, 217461: 4.6765, 226777: 0.9496, 227686: 1.4439, 231799: -4.4233, 234238: 6.5979, 238630: -3.3863, 243381: -2.0553, 247409: 0.8427, 254165: 0.2084, 257750: -1.1924})
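
Each index in the SparseVector is a bucket in HashingTF's 262,144-dimensional hashed feature space, so the coefficients are not labeled with variable names. As a sketch, you can trace which buckets a record's tokens occupy by pushing a sample through the first two pipeline stages:

In [ ]:
# Map one record's tokens to their hash-bucket indices (illustrative)
sample_df = sqlContext.createDataFrame([('x', '19 2 00 01 0')], ['ID', 'PersonInfo'])
tokenized = model.stages[0].transform(sample_df)   # Tokenizer
hashed = model.stages[1].transform(tokenized)      # HashingTF
hashed.select('features').show(truncate=False)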

Predict for Test Data

In [28]:
# Make predictions for test data and print columns of interest
prediction = model.transform(test)
selected = prediction.select("PersonInfo", "prediction", "probability")
for row in selected.collect():
    print (row)
Row(PersonInfo='18 1 10 00 0', prediction=0.0, probability=DenseVector([0.9743, 0.0257]))
Row(PersonInfo='15 2 00 00 0', prediction=1.0, probability=DenseVector([0.0034, 0.9966]))
Row(PersonInfo='32 3 00 00 2', prediction=1.0, probability=DenseVector([0.0025, 0.9975]))
Row(PersonInfo='28 3 00 00 0', prediction=0.0, probability=DenseVector([0.9881, 0.0119]))
Row(PersonInfo='17 1 10 00 0', prediction=0.0, probability=DenseVector([0.7076, 0.2924]))
Row(PersonInfo='29 1 00 00 2', prediction=0.0, probability=DenseVector([0.8138, 0.1862]))
Row(PersonInfo='35 2 11 00 1', prediction=0.0, probability=DenseVector([0.6419, 0.3581]))
Row(PersonInfo='19 1 10 00 2', prediction=0.0, probability=DenseVector([0.5569, 0.4431]))
Row(PersonInfo='27 1 10 00 0', prediction=1.0, probability=DenseVector([0.0092, 0.9908]))
Row(PersonInfo='19 1 00 00 2', prediction=0.0, probability=DenseVector([0.7692, 0.2308]))
Row(PersonInfo='21 1 00 00 0', prediction=0.0, probability=DenseVector([0.9455, 0.0545]))
Row(PersonInfo='18 1 10 01 0', prediction=0.0, probability=DenseVector([0.9395, 0.0605]))
Row(PersonInfo='32 1 00 00 4', prediction=0.0, probability=DenseVector([0.685, 0.315]))
Row(PersonInfo='19 3 00 00 0', prediction=0.0, probability=DenseVector([0.6023, 0.3977]))
Row(PersonInfo='24 1 00 00 2', prediction=0.0, probability=DenseVector([0.8885, 0.1115]))
Row(PersonInfo='22 3 10 00 0', prediction=0.0, probability=DenseVector([0.819, 0.181]))
Row(PersonInfo='22 1 10 00 0', prediction=0.0, probability=DenseVector([0.9468, 0.0532]))
Row(PersonInfo='19 3 00 00 0', prediction=0.0, probability=DenseVector([0.6023, 0.3977]))
Row(PersonInfo='21 3 10 01 0', prediction=1.0, probability=DenseVector([0.4049, 0.5951]))
Row(PersonInfo='17 3 00 00 0', prediction=0.0, probability=DenseVector([0.62, 0.38]))
Row(PersonInfo='17 3 00 00 0', prediction=0.0, probability=DenseVector([0.62, 0.38]))
Row(PersonInfo='23 3 00 00 2', prediction=0.0, probability=DenseVector([0.6267, 0.3733]))
Row(PersonInfo='24 3 00 00 2', prediction=0.0, probability=DenseVector([0.6695, 0.3305]))
Row(PersonInfo='22 2 01 00 2', prediction=0.0, probability=DenseVector([0.7483, 0.2517]))
Row(PersonInfo='31 3 10 00 2', prediction=0.0, probability=DenseVector([0.983, 0.017]))
Row(PersonInfo='23 3 10 00 1', prediction=0.0, probability=DenseVector([0.6971, 0.3029]))
Row(PersonInfo='16 1 10 00 0', prediction=0.0, probability=DenseVector([0.7014, 0.2986]))
Row(PersonInfo='18 2 00 00 0', prediction=0.0, probability=DenseVector([0.9651, 0.0349]))
Row(PersonInfo='32 1 11 00 4', prediction=1.0, probability=DenseVector([0.0191, 0.9809]))
Row(PersonInfo='20 2 10 00 0', prediction=0.0, probability=DenseVector([0.5077, 0.4923]))
Row(PersonInfo='32 1 00 00 0', prediction=1.0, probability=DenseVector([0.0173, 0.9827]))
Row(PersonInfo='19 3 00 00 0', prediction=0.0, probability=DenseVector([0.6023, 0.3977]))
Row(PersonInfo='23 1 00 00 0', prediction=0.0, probability=DenseVector([0.9219, 0.0781]))
Row(PersonInfo='21 3 00 00 2', prediction=0.0, probability=DenseVector([0.7115, 0.2885]))
Row(PersonInfo='19 2 00 01 0', prediction=1.0, probability=DenseVector([0.4014, 0.5986]))
Row(PersonInfo='33 3 00 00 3', prediction=0.0, probability=DenseVector([0.9804, 0.0196]))
Row(PersonInfo='22 1 00 00 1', prediction=0.0, probability=DenseVector([0.9897, 0.0103]))
Row(PersonInfo='17 3 00 00 1', prediction=0.0, probability=DenseVector([0.7685, 0.2315]))
Row(PersonInfo='30 3 01 01 2', prediction=0.0, probability=DenseVector([0.5031, 0.4969]))
Row(PersonInfo='25 3 01 10 0', prediction=1.0, probability=DenseVector([0.1637, 0.8363]))
Row(PersonInfo='23 3 00 01 1', prediction=0.0, probability=DenseVector([0.7142, 0.2858]))
Row(PersonInfo='24 2 01 00 1', prediction=0.0, probability=DenseVector([0.7654, 0.2346]))
Row(PersonInfo='16 1 10 00 0', prediction=0.0, probability=DenseVector([0.7014, 0.2986]))
Row(PersonInfo='19 1 10 00 0', prediction=0.0, probability=DenseVector([0.692, 0.308]))
Row(PersonInfo='30 1 00 00 1', prediction=0.0, probability=DenseVector([0.9886, 0.0114]))
Row(PersonInfo='19 1 10 10 0', prediction=1.0, probability=DenseVector([0.4586, 0.5414]))
Row(PersonInfo='21 1 10 10 1', prediction=0.0, probability=DenseVector([0.8338, 0.1662]))
Row(PersonInfo='25 2 00 10 0', prediction=1.0, probability=DenseVector([0.3409, 0.6591]))
Row(PersonInfo='22 1 00 00 0', prediction=0.0, probability=DenseVector([0.9793, 0.0207]))
Row(PersonInfo='32 1 00 00 2', prediction=1.0, probability=DenseVector([0.0097, 0.9903]))
Row(PersonInfo='29 1 10 00 2', prediction=0.0, probability=DenseVector([0.6223, 0.3777]))
Row(PersonInfo='33 1 00 01 1', prediction=0.0, probability=DenseVector([0.9969, 0.0031]))
Row(PersonInfo='28 3 00 00 1', prediction=0.0, probability=DenseVector([0.9941, 0.0059]))
Row(PersonInfo='28 3 00 00 0', prediction=0.0, probability=DenseVector([0.9881, 0.0119]))
Row(PersonInfo='16 3 00 00 1', prediction=0.0, probability=DenseVector([0.7632, 0.2368]))
Row(PersonInfo='20 1 00 00 1', prediction=0.0, probability=DenseVector([0.9529, 0.0471]))
Row(PersonInfo='26 3 00 00 0', prediction=0.0, probability=DenseVector([0.6428, 0.3572]))
Row(PersonInfo='21 1 00 00 1', prediction=0.0, probability=DenseVector([0.9724, 0.0276]))
Row(PersonInfo='22 1 00 00 0', prediction=0.0, probability=DenseVector([0.9793, 0.0207]))
Row(PersonInfo='31 1 00 00 2', prediction=0.0, probability=DenseVector([0.9983, 0.0017]))
Row(PersonInfo='24 1 00 00 1', prediction=0.0, probability=DenseVector([0.9666, 0.0334]))
Row(PersonInfo='24 1 11 00 0', prediction=1.0, probability=DenseVector([0.1131, 0.8869]))
Row(PersonInfo='21 3 00 00 0', prediction=0.0, probability=DenseVector([0.8151, 0.1849]))
Row(PersonInfo='20 3 00 01 0', prediction=0.0, probability=DenseVector([0.5087, 0.4913]))
Row(PersonInfo='26 1 11 00 0', prediction=1.0, probability=DenseVector([0.0596, 0.9404]))
Row(PersonInfo='28 3 11 01 0', prediction=1.0, probability=DenseVector([0.2327, 0.7673]))
Row(PersonInfo='27 2 00 01 0', prediction=1.0, probability=DenseVector([0.0028, 0.9972]))
Row(PersonInfo='17 1 10 00 0', prediction=0.0, probability=DenseVector([0.7076, 0.2924]))
Row(PersonInfo='18 3 00 00 0', prediction=0.0, probability=DenseVector([0.9624, 0.0376]))
Row(PersonInfo='21 3 01 00 4', prediction=0.0, probability=DenseVector([0.9955, 0.0045]))
Row(PersonInfo='31 1 11 00 1', prediction=0.0, probability=DenseVector([0.9514, 0.0486]))
Row(PersonInfo='23 2 10 00 1', prediction=0.0, probability=DenseVector([0.7135, 0.2865]))
Row(PersonInfo='24 2 10 00 0', prediction=0.0, probability=DenseVector([0.5963, 0.4037]))
Row(PersonInfo='22 1 10 00 1', prediction=0.0, probability=DenseVector([0.9731, 0.0269]))
Row(PersonInfo='23 1 11 00 0', prediction=1.0, probability=DenseVector([0.0956, 0.9044]))
Row(PersonInfo='26 3 01 10 1', prediction=1.0, probability=DenseVector([0.3611, 0.6389]))
Row(PersonInfo='20 3 00 00 3', prediction=0.0, probability=DenseVector([0.5667, 0.4333]))
Row(PersonInfo='28 1 10 00 2', prediction=0.0, probability=DenseVector([0.9856, 0.0144]))
Row(PersonInfo='14 3 00 00 2', prediction=0.0, probability=DenseVector([0.7468, 0.2532]))
Row(PersonInfo='23 3 10 00 0', prediction=0.0, probability=DenseVector([0.5308, 0.4692]))
Row(PersonInfo='17 2 00 10 0', prediction=1.0, probability=DenseVector([0.3996, 0.6004]))
In [29]:
#Tabulate the predicted outcome
prediction.select("prediction").groupBy("prediction").count().show(truncate=False)
+----------+-----+
|prediction|count|
+----------+-----+
|0.0       |63   |
|1.0       |18   |
+----------+-----+

In [30]:
#Tabulate the actual outcome
prediction.select("label").groupBy("label").count().show(truncate=False)
+-----+-----+
|label|count|
+-----+-----+
|0.0  |57   |
|1.0  |24   |
+-----+-----+

In [31]:
# This table shows:
# 1. The number of low birth weight infants predicted as having low birth weight
# 2. The number of low birth weight infants predicted as not having low birth weight
# 3. The number of regular birth weight infants predicted as having low birth weight
# 4. The number of regular birth weight infants predicted as not having low birth weight

prediction.crosstab('label', 'prediction').show()
+----------------+---+---+
|label_prediction|0.0|1.0|
+----------------+---+---+
|             1.0| 16|  8|
|             0.0| 47| 10|
+----------------+---+---+
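
Reading the table: of the 18 cases predicted as low birth weight, 8 are correct, and of the 24 actual low-birth-weight cases, 8 are found. A sketch deriving the test-set metrics by hand from these counts (they match the evaluator output further below):

In [ ]:
# Derive the metrics from the confusion matrix above (illustrative arithmetic)
TP, FN = 8, 16    # actual label 1.0 row
FP, TN = 10, 47   # actual label 0.0 row
precision = TP / float(TP + FP)                     # 8/18  ≈ 0.444
recall = TP / float(TP + FN)                        # 8/24  ≈ 0.333
accuracy = (TP + TN) / float(TP + TN + FP + FN)     # 55/81 ≈ 0.679
f1 = 2 * precision * recall / (precision + recall)  # ≈ 0.381
print(precision, recall, accuracy, f1)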

Evaluate the Model

We evaluate the model on both the training set and the test set. The training-set metrics show how well the model fits the data it was built from; the test-set metrics estimate its predictive accuracy on new data.

In [32]:
# Evaluate the Logistic Regression model on the training set
# Select (prediction, true label) pairs and compute the training error
pred_lr=model.transform(train).select("prediction", "label")
eval_lr=MulticlassClassificationEvaluator (
    labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy_lr=eval_lr.evaluate(pred_lr)
# create RDD
predictionAndLabels_lr=pred_lr.rdd
metrics_lr=MulticlassMetrics(predictionAndLabels_lr)
precision_lr=metrics_lr.precision(1.0)
recall_lr=metrics_lr.recall(1.0)
f1Measure_lr = metrics_lr.fMeasure(1.0, 1.0)
print ("Model evaluation for the training data")
print ("Accuracy = %s" %accuracy_lr)
print ("Error = %s" % (1-accuracy_lr))
print ("Precision = %s" %precision_lr)
print ("Recall = %s" %recall_lr)
print("F1 Measure = %s" % f1Measure_lr)
Model evaluation for the training data
Accuracy = 0.7614678899082569
Error = 0.23853211009174313
Precision = 0.6923076923076923
Recall = 0.5
F1 Measure = 0.5806451612903226
In [33]:
# Evaluate the Logistic Regression model on a test set
# Select (prediction, true label) and compute test error
pred_lr=model.transform(test).select("prediction", "label")
eval_lr=MulticlassClassificationEvaluator (
    labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy_lr=eval_lr.evaluate(pred_lr)
# create RDD
predictionAndLabels_lr=pred_lr.rdd
metrics_lr=MulticlassMetrics(predictionAndLabels_lr)
precision_lr=metrics_lr.precision(1.0)
recall_lr=metrics_lr.recall(1.0)
f1Measure_lr = metrics_lr.fMeasure(1.0, 1.0)
print ("Model evaluation for the test data")
print ("Accuracy = %s" %accuracy_lr)
print ("Error = %s" % (1-accuracy_lr))
print ("Precision = %s" %precision_lr)
print ("Recall = %s" %recall_lr)
print ("F1 Measure = %s" % f1Measure_lr)
Model evaluation for the test data
Accuracy = 0.6790123456790124
Error = 0.32098765432098764
Precision = 0.4444444444444444
Recall = 0.3333333333333333
F1 Measure = 0.380952380952381
In [34]:
bin_lr=BinaryClassificationMetrics(predictionAndLabels_lr)

# Area under precision-recall curve
print("Area under PR = %s" % bin_lr.areaUnderPR)
# Area under Receiver operating characteristic curve
print("Area under ROC = %s" % bin_lr.areaUnderROC)
Area under PR = 0.3950617283950617
Area under ROC = 0.5789473684210527
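
Note that predictionAndLabels_lr holds hard 0/1 predictions, while BinaryClassificationMetrics is designed for (score, label) pairs, so the areas above are effectively computed from a single threshold. A sketch of the score-based alternative, using the class-1 probability as the score (an assumption, not part of the original flow):

In [ ]:
# Score-based AUC using the predicted probability of class 1
scores_lr = model.transform(test).select('probability', 'label')\
    .rdd.map(lambda row: (float(row['probability'][1]), row['label']))
bin_scores = BinaryClassificationMetrics(scores_lr)
print('Area under ROC (probability scores) = %s' % bin_scores.areaUnderROC)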

ROC Curve Plot

The fitted logistic regression stage exposes a pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary object through its summary attribute. Its roc field holds the coordinates of the points on the receiver operating characteristic (ROC) curve, computed on the training data, as a DataFrame with two columns: FPR (false positive rate) and TPR (true positive rate).

In [35]:
# The first 20 ROC curve points
model.stages[2].summary.roc.show()
+------------------+--------------------+
|               FPR|                 TPR|
+------------------+--------------------+
|               0.0|                 0.0|
|               0.0|0.027777777777777776|
|               0.0| 0.05555555555555555|
|               0.0| 0.08333333333333333|
|               0.0|  0.1111111111111111|
|               0.0|  0.1388888888888889|
|               0.0| 0.16666666666666666|
|               0.0| 0.19444444444444445|
|               0.0|  0.2222222222222222|
|               0.0|                0.25|
|               0.0|  0.2777777777777778|
|0.0136986301369863|  0.2777777777777778|
|0.0273972602739726|  0.2777777777777778|
|0.0273972602739726|  0.3055555555555556|
|0.0273972602739726|  0.3333333333333333|
|0.0273972602739726|  0.3611111111111111|
|0.0273972602739726|  0.3888888888888889|
|0.0273972602739726|  0.4166666666666667|
|0.0273972602739726|  0.4444444444444444|
|0.0410958904109589|  0.4444444444444444|
+------------------+--------------------+
only showing top 20 rows

In [36]:
import matplotlib.pyplot as plt
%matplotlib inline
ROC=model.stages[2].summary.roc
df=ROC.toPandas()

df.plot(x='FPR', y='TPR', legend=False)
Out[36]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f1ef913e8d0>
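
For reference, the chance diagonal can be overlaid on the same plot (a sketch):

In [ ]:
# Overlay the random-classifier baseline on the ROC plot
ax = df.plot(x='FPR', y='TPR', legend=False, title='ROC curve (training summary)')
ax.plot([0, 1], [0, 1], linestyle='--', color='gray')  # chance line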