Using XGBoost in SageMaker (Batch Transform) - High-Level API

20 minute read

Using XGBoost in SageMaker (Batch Transform)


As an introduction to SageMaker’s high-level Python API, we will look at a relatively simple problem: using the Boston Housing Dataset to predict the median value of a home in the Boston, Massachusetts area.

The documentation for the high-level API can be found on the ReadTheDocs page.

General Outline

Typically, when using a notebook instance with SageMaker, you will proceed through the following steps. Of course, not every step will need to be done with each project. Also, there is quite a lot of room for variation in many of the steps, as you will see throughout these lessons.

  1. Download or otherwise retrieve the data.
  2. Process / Prepare the data.
  3. Upload the processed data to S3.
  4. Train a chosen model.
  5. Test the trained model (typically using a batch transform job).
  6. Deploy the trained model.
  7. Use the deployed model.

In this post we will cover only steps 1 through 5, since we just want to get a feel for using SageMaker. Later posts will cover deploying a trained model in much more detail.

Step 0: Setting up the notebook

We begin by setting up everything the notebook needs. To start, that means loading all of the Python modules we will use.

%matplotlib inline

import os

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from sklearn.datasets import load_boston
import sklearn.model_selection

In addition to the modules above, we need to import the various bits of SageMaker that we will be using.

import sagemaker
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.predictor import csv_serializer

# This is an object that represents the SageMaker session that we are currently operating in. This
# object contains some useful information that we will need to access later such as our region.
session = sagemaker.Session()

# This is an object that represents the IAM role that we are currently assigned. When we construct
# and launch the training job later we will need to tell it what IAM role it should have. Since our
# use case is relatively simple we will simply assign the training job the role we currently have.
role = get_execution_role()

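NOTE: The code in this post was written against version 1.72.0 of the SageMaker Python SDK, as printed below. If you are following along, install this exact version (for example with pip install sagemaker==1.72.0); otherwise you might run into issues, since version 2 of the SDK renamed several of the classes and parameters used here.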

print(sagemaker.__version__)
1.72.0

Step 1: Downloading the data

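Fortunately, this dataset can be retrieved using scikit-learn, so this step is relatively straightforward. Note that load_boston was deprecated in scikit-learn 1.0 and removed in 1.2, so running this cell as written requires an older scikit-learn release.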

boston = load_boston()

Step 2: Preparing and splitting the data

Given that this is clean tabular data, we don’t need to do any processing. However, we do need to split the rows of the dataset into training, test and validation sets.

# First we package up the input data and the target variable (the median value) as pandas dataframes. This
# will make saving the data to a file a little easier later on.

X_bos_pd = pd.DataFrame(boston.data, columns=boston.feature_names)
Y_bos_pd = pd.DataFrame(boston.target)

# We split the dataset into 2/3 training and 1/3 testing sets.
X_train, X_test, Y_train, Y_test = sklearn.model_selection.train_test_split(X_bos_pd, Y_bos_pd, test_size=0.33)

# Then we split the training set further into 2/3 training and 1/3 validation sets.
X_train, X_val, Y_train, Y_val = sklearn.model_selection.train_test_split(X_train, Y_train, test_size=0.33)
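The comments describe each split as 2/3 versus 1/3, but because the second split only divides the training portion, the final proportions are smaller than that. The sketch below assumes scikit-learn's rule for a float test_size (the test count is rounded up and the remainder goes to training) and the dataset's 506 rows; under that assumption it reproduces the row counts that appear later in the training log (227 training rows, 112 validation rows).

```python
import math


def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) for a float test_size, rounding the test count up."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


n_total = 506  # rows in the Boston Housing dataset
n_trainval, n_test = split_sizes(n_total, 0.33)  # first split: train+validation vs test
n_train, n_val = split_sizes(n_trainval, 0.33)   # second split: train vs validation

for name, n in [("train", n_train), ("validation", n_val), ("test", n_test)]:
    print("{:<10} {:>3} rows ({:.1%} of the data)".format(name, n, n / n_total))
```

So the final split is roughly 45% training, 22% validation and 33% test. If you want the same split on every run, train_test_split also accepts a random_state argument.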

Step 3: Uploading the data files to S3

When SageMaker runs a training job, it launches a container that performs the training, and that container reads its input data from S3. This means we need to upload the data we want to train on to S3. Batch transform jobs likewise expect their input data to be stored on S3. We can use the SageMaker API to do the upload and hide some of the details.

Save the data locally

First we need to create the test, train and validation CSV files, which we will then upload to S3.

# This is our local data directory. We need to make sure that it exists.
data_dir = '../data/boston'
if not os.path.exists(data_dir):
    os.makedirs(data_dir)
# We use pandas to save our test, train and validation data to CSV files. Note that we make sure not to include a
# header or an index, as the built-in algorithms provided by Amazon require this. Also, for the train and
# validation data, the first entry in each row is assumed to be the target variable.

X_test.to_csv(os.path.join(data_dir, 'test.csv'), header=False, index=False)

pd.concat([Y_val, X_val], axis=1).to_csv(os.path.join(data_dir, 'validation.csv'), header=False, index=False)
pd.concat([Y_train, X_train], axis=1).to_csv(os.path.join(data_dir, 'train.csv'), header=False, index=False)
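The layout these files end up with is simple enough to show without pandas. The stdlib-only sketch below (the helper names and sample numbers are illustrative, not part of the original notebook) writes the two row shapes used here: label first for training and validation, features only for the test file, and no header or index column in either case.

```python
import csv
import io


def to_training_csv(labels, feature_rows):
    """Rows as 'label,feature_1,...,feature_n' with no header and no index column."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    for label, features in zip(labels, feature_rows):
        writer.writerow([label, *features])  # the target goes in the first column
    return buffer.getvalue()


def to_inference_csv(feature_rows):
    """Test/inference rows carry only the features, with no label column."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerows(feature_rows)
    return buffer.getvalue()


print(to_training_csv([24.0], [[0.00632, 18.0, 2.31]]), end="")
print(to_inference_csv([[0.00632, 18.0, 2.31]]), end="")
```

Because test.csv carries no label column, the predictions that come back from the batch transform job can only be matched to Y_test by row order, which is why the test rows are written in the same order as Y_test.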

Upload to S3

Since we are currently running inside of a SageMaker session, we can use the object which represents this session to upload our data to the ‘default’ S3 bucket. Note that it is good practice to provide a custom prefix (essentially an S3 folder) to make sure that you don’t accidentally interfere with data uploaded from some other notebook or project.

prefix = 'sk-boston-xgboost-HL'

test_location = session.upload_data(os.path.join(data_dir, 'test.csv'), key_prefix=prefix)
val_location = session.upload_data(os.path.join(data_dir, 'validation.csv'), key_prefix=prefix)
train_location = session.upload_data(os.path.join(data_dir, 'train.csv'), key_prefix=prefix)
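Each upload_data call returns the S3 URI of the uploaded object, and these strings are what we later hand to the training and batch transform jobs. For a single file, the URI has the shape s3://<bucket>/<key_prefix>/<file name>, and the default bucket follows the sagemaker-<region>-<account id> pattern visible in the download line near the end of this post. The hypothetical helpers below only build and split such strings (the account ID is made up); they do not contact S3.

```python
import os


def s3_object_uri(bucket, key_prefix, local_path):
    """URI that a single-file upload of local_path under key_prefix would be given."""
    filename = os.path.basename(local_path)
    return "s3://{}/{}/{}".format(bucket, key_prefix, filename)


def split_s3_uri(uri):
    """Split 's3://bucket/key' into (bucket, key)."""
    if not uri.startswith("s3://"):
        raise ValueError("not an S3 URI: {!r}".format(uri))
    bucket, _, key = uri[len("s3://"):].partition("/")
    return bucket, key


bucket = "sagemaker-us-east-1-123456789012"  # made-up account ID
uri = s3_object_uri(bucket, "sk-boston-xgboost-HL", "../data/boston/test.csv")
print(uri)
print(split_s3_uri(uri))
```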

Step 4: Train the XGBoost model

Now that we have the training and validation data uploaded to S3, we can construct our XGBoost model and train it. We will be making use of the high level SageMaker API to do this which will make the resulting code a little easier to read at the cost of some flexibility.

To construct an estimator (the object we wish to train), we need to provide the location of a container that holds the training code. Since we are using a built-in algorithm, this container is provided by Amazon. However, the full name of the container is lengthy and depends on the region we are operating in. Fortunately, SageMaker provides a utility method called get_image_uri that constructs the image name for us.

To use get_image_uri, we provide it with our current region, which can be obtained from the session object, and the name of the algorithm we wish to use. In this notebook we use XGBoost; however, you could try another algorithm if you wish. The list of built-in algorithms can be found in the list of Common Parameters.

# As stated above, we use this utility method to construct the image name for the training container.
container = get_image_uri(session.boto_region_name, 'xgboost')

# Now that we know which container to use, we can construct the estimator object.
xgb = sagemaker.estimator.Estimator(container, # The image name of the training container
                                    role,      # The IAM role to use (our current role in this case)
                                    train_instance_count=1, # The number of instances to use for training
                                    train_instance_type='ml.m4.xlarge', # The type of instance to use for training
                                    output_path='s3://{}/{}/output'.format(session.default_bucket(), prefix),
                                                                        # Where to save the output (the model artifacts)
                                    sagemaker_session=session) # The current SageMaker session
'get_image_uri' method will be deprecated in favor of 'ImageURIProvider' class in SageMaker Python SDK v2.
There is a more up to date SageMaker XGBoost image. To use the newer image, please set 'repo_version'='1.0-1'. For example:
	get_image_uri(region, 'xgboost', '1.0-1').
Parameter image_name will be renamed to image_uri in SageMaker Python SDK v2.
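These warnings are the 1.72 SDK flagging names that change in version 2. If you port this notebook to v2, the released SDK replaces get_image_uri with sagemaker.image_uris.retrieve and s3_input with TrainingInput, and drops the train_ prefix from several Estimator arguments. The sketch below is a hypothetical helper that applies those keyword renames to a dictionary of Estimator arguments; its mapping covers only the arguments listed, so check the SDK's v2 migration notes for anything else.

```python
# Hypothetical helper: rename SageMaker Python SDK v1 Estimator keyword
# arguments (as used in this post) to their v2 names.
V1_TO_V2_ESTIMATOR_KWARGS = {
    "image_name": "image_uri",
    "train_instance_count": "instance_count",
    "train_instance_type": "instance_type",
    "train_volume_size": "volume_size",
    "train_max_run": "max_run",
}


def to_v2_kwargs(kwargs):
    """Return a copy of kwargs with the v1 names above replaced by their v2 names."""
    renamed = {}
    for key, value in kwargs.items():
        new_key = V1_TO_V2_ESTIMATOR_KWARGS.get(key, key)
        if new_key in renamed:
            raise ValueError("{!r} given under both its v1 and v2 names".format(new_key))
        renamed[new_key] = value
    return renamed


v1_kwargs = {
    "train_instance_count": 1,
    "train_instance_type": "ml.m4.xlarge",
    "output_path": "s3://my-bucket/sk-boston-xgboost-HL/output",  # placeholder bucket
}
print(to_v2_kwargs(v1_kwargs))
```

Under v2, the container name would then come from something like sagemaker.image_uris.retrieve('xgboost', session.boto_region_name, '1.0-1'). Treat this as a starting point for a migration rather than a complete one.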

Before asking SageMaker to begin the training job, we should set any model-specific hyperparameters. There are quite a few that can be set for the XGBoost algorithm; below are just a few of them. Note that early_stopping_rounds relies on the validation channel we supply when fitting: training stops once the validation metric has not improved for that many rounds. If you would like to change the hyperparameters below or modify additional ones, you can find more information on the XGBoost hyperparameter page.

xgb.set_hyperparameters(max_depth=5,
                        eta=0.2,
                        gamma=4,
                        min_child_weight=6,
                        subsample=0.8,
                        objective='reg:linear',
                        early_stopping_rounds=10,
                        num_round=200)
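Among the hyperparameters above, eta is the learning rate: each new tree's output is scaled by eta before being added to the running prediction, so smaller values make each round take a smaller step. The toy sketch below assumes that every "tree" predicts the current residual exactly (real trees limited by max_depth do not) and starts from a constant prediction of 0.5, XGBoost's default base_score; the target of 22.5 is chosen to be close to the dataset's average median value. Under those assumptions the residual shrinks by a factor of (1 - eta) = 0.8 every round.

```python
def toy_boosting_residuals(target, base_score, eta, rounds):
    """Residual after each round when every 'tree' fits the residual exactly (toy model)."""
    prediction = base_score
    residuals = []
    for _ in range(rounds):
        residual = target - prediction
        prediction += eta * residual  # this round's tree, shrunk by the learning rate
        residuals.append(target - prediction)
    return residuals


for i, r in enumerate(toy_boosting_residuals(22.5, 0.5, 0.2, 4)):
    print("round {}: residual {:.3f}".format(i, r))
```

The training log further down shows a similar pattern early on: train-rmse goes 20.03, 16.33, 13.34, 10.98 over the first four rounds, each a little over 0.8 times the previous value.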

Now that our estimator object is completely set up, it is time to train it. To do this, we make sure that SageMaker knows our input data is in CSV format and then call the fit method.

# This is a wrapper around the location of our train and validation data, to make sure that SageMaker
# knows our data is in csv format.
s3_input_train = sagemaker.s3_input(s3_data=train_location, content_type='csv')
s3_input_validation = sagemaker.s3_input(s3_data=val_location, content_type='csv')

xgb.fit({'train': s3_input_train, 'validation': s3_input_validation})
's3_input' class will be renamed to 'TrainingInput' in SageMaker Python SDK v2.
's3_input' class will be renamed to 'TrainingInput' in SageMaker Python SDK v2.


2020-12-01 20:31:06 Starting - Starting the training job...
2020-12-01 20:31:08 Starting - Launching requested ML instances......
2020-12-01 20:32:21 Starting - Preparing the instances for training......
2020-12-01 20:33:29 Downloading - Downloading input data...
2020-12-01 20:34:00 Training - Downloading the training image..
2020-12-01 20:34:33 Uploading - Uploading generated training model
2020-12-01 20:34:33 Completed - Training job completed
Arguments: train
[2020-12-01:20:34:22:INFO] Running standalone xgboost training.
[2020-12-01:20:34:22:INFO] File size need to be processed in the node: 0.02mb. Available memory size in the node: 8458.9mb
[2020-12-01:20:34:22:INFO] Determined delimiter of CSV input is ','
[20:34:22] S3DistributionType set as FullyReplicated
[20:34:22] 227x13 matrix with 2951 entries loaded from /opt/ml/input/data/train?format=csv&label_column=0&delimiter=,
[2020-12-01:20:34:22:INFO] Determined delimiter of CSV input is ','
[20:34:22] S3DistributionType set as FullyReplicated
[20:34:22] 112x13 matrix with 1456 entries loaded from /opt/ml/input/data/validation?format=csv&label_column=0&delimiter=,
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 0 pruned nodes, max_depth=3
[0]#011train-rmse:20.0336#011validation-rmse:19.247
Multiple eval metrics have been passed: 'validation-rmse' will be used for early stopping.

Will train until validation-rmse hasn't improved in 10 rounds.
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 0 pruned nodes, max_depth=4
[1]#011train-rmse:16.3345#011validation-rmse:15.8973
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 0 pruned nodes, max_depth=3
[2]#011train-rmse:13.3363#011validation-rmse:13.2384
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 0 pruned nodes, max_depth=4
[3]#011train-rmse:10.977#011validation-rmse:11.2375
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=4
[4]#011train-rmse:9.04805#011validation-rmse:9.65843
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=5
[5]#011train-rmse:7.57107#011validation-rmse:8.53022
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 16 extra nodes, 4 pruned nodes, max_depth=5
[6]#011train-rmse:6.37385#011validation-rmse:7.54467
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 26 extra nodes, 2 pruned nodes, max_depth=5
[7]#011train-rmse:5.39017#011validation-rmse:6.76722
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 24 extra nodes, 2 pruned nodes, max_depth=5
[8]#011train-rmse:4.61565#011validation-rmse:6.24955
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 20 extra nodes, 4 pruned nodes, max_depth=5
[9]#011train-rmse:4.10733#011validation-rmse:5.94645
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 18 extra nodes, 0 pruned nodes, max_depth=5
[10]#011train-rmse:3.61967#011validation-rmse:5.64428
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 18 extra nodes, 4 pruned nodes, max_depth=5
[11]#011train-rmse:3.22507#011validation-rmse:5.39222
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 16 extra nodes, 4 pruned nodes, max_depth=5
[12]#011train-rmse:2.97609#011validation-rmse:5.31567
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 20 extra nodes, 0 pruned nodes, max_depth=5
[13]#011train-rmse:2.76481#011validation-rmse:5.21326
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 30 extra nodes, 0 pruned nodes, max_depth=5
[14]#011train-rmse:2.53945#011validation-rmse:5.08828
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 18 extra nodes, 2 pruned nodes, max_depth=5
[15]#011train-rmse:2.41408#011validation-rmse:5.01763
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=5
[16]#011train-rmse:2.30863#011validation-rmse:4.96369
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 18 extra nodes, 0 pruned nodes, max_depth=5
[17]#011train-rmse:2.19086#011validation-rmse:4.8799
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 22 extra nodes, 2 pruned nodes, max_depth=5
[18]#011train-rmse:2.0997#011validation-rmse:4.86096
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 16 extra nodes, 10 pruned nodes, max_depth=5
[19]#011train-rmse:2.04434#011validation-rmse:4.8503
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 24 extra nodes, 4 pruned nodes, max_depth=5
[20]#011train-rmse:1.95575#011validation-rmse:4.86373
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 0 pruned nodes, max_depth=5
[21]#011train-rmse:1.91385#011validation-rmse:4.85857
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 4 pruned nodes, max_depth=5
[22]#011train-rmse:1.88345#011validation-rmse:4.85149
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 18 extra nodes, 2 pruned nodes, max_depth=5
[23]#011train-rmse:1.8152#011validation-rmse:4.83242
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 16 extra nodes, 0 pruned nodes, max_depth=5
[24]#011train-rmse:1.79216#011validation-rmse:4.83176
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 18 extra nodes, 8 pruned nodes, max_depth=5
[25]#011train-rmse:1.74466#011validation-rmse:4.805
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 0 pruned nodes, max_depth=5
[26]#011train-rmse:1.72415#011validation-rmse:4.80736
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 20 extra nodes, 2 pruned nodes, max_depth=5
[27]#011train-rmse:1.67875#011validation-rmse:4.79914
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 2 pruned nodes, max_depth=5
[28]#011train-rmse:1.64787#011validation-rmse:4.78477
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 22 extra nodes, 4 pruned nodes, max_depth=5
[29]#011train-rmse:1.6094#011validation-rmse:4.7747
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 0 pruned nodes, max_depth=5
[30]#011train-rmse:1.58964#011validation-rmse:4.74697
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 4 pruned nodes, max_depth=5
[31]#011train-rmse:1.55695#011validation-rmse:4.72472
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 2 pruned nodes, max_depth=5
[32]#011train-rmse:1.53066#011validation-rmse:4.71271
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 2 pruned nodes, max_depth=5
[33]#011train-rmse:1.52446#011validation-rmse:4.71454
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 8 pruned nodes, max_depth=5
[34]#011train-rmse:1.51085#011validation-rmse:4.72342
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 22 extra nodes, 6 pruned nodes, max_depth=5
[35]#011train-rmse:1.46318#011validation-rmse:4.68268
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 16 extra nodes, 6 pruned nodes, max_depth=5
[36]#011train-rmse:1.43235#011validation-rmse:4.67333
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 16 extra nodes, 4 pruned nodes, max_depth=5
[37]#011train-rmse:1.40406#011validation-rmse:4.64896
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 4 pruned nodes, max_depth=3
[38]#011train-rmse:1.39188#011validation-rmse:4.63862
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 14 pruned nodes, max_depth=4
[39]#011train-rmse:1.36541#011validation-rmse:4.62575
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 2 pruned nodes, max_depth=5
[40]#011train-rmse:1.34878#011validation-rmse:4.60862
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 6 pruned nodes, max_depth=4
[41]#011train-rmse:1.34519#011validation-rmse:4.60156
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 18 extra nodes, 20 pruned nodes, max_depth=5
[42]#011train-rmse:1.30085#011validation-rmse:4.58674
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 8 pruned nodes, max_depth=5
[43]#011train-rmse:1.28972#011validation-rmse:4.58649
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 14 pruned nodes, max_depth=2
[44]#011train-rmse:1.27292#011validation-rmse:4.60772
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 10 pruned nodes, max_depth=4
[45]#011train-rmse:1.24397#011validation-rmse:4.60602
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 12 pruned nodes, max_depth=2
[46]#011train-rmse:1.23392#011validation-rmse:4.57387
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 10 pruned nodes, max_depth=3
[47]#011train-rmse:1.21994#011validation-rmse:4.57972
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 4 pruned nodes, max_depth=5
[48]#011train-rmse:1.20531#011validation-rmse:4.59388
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 4 pruned nodes, max_depth=3
[49]#011train-rmse:1.18616#011validation-rmse:4.58165
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 8 pruned nodes, max_depth=5
[50]#011train-rmse:1.16562#011validation-rmse:4.58707
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 12 pruned nodes, max_depth=5
[51]#011train-rmse:1.15204#011validation-rmse:4.58624
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 4 pruned nodes, max_depth=5
[52]#011train-rmse:1.14192#011validation-rmse:4.57551
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 8 pruned nodes, max_depth=4
[53]#011train-rmse:1.13449#011validation-rmse:4.5764
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 16 extra nodes, 10 pruned nodes, max_depth=5
[54]#011train-rmse:1.10662#011validation-rmse:4.5707
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 2 pruned nodes, max_depth=5
[55]#011train-rmse:1.09465#011validation-rmse:4.57936
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 10 pruned nodes, max_depth=2
[56]#011train-rmse:1.08436#011validation-rmse:4.569
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 16 pruned nodes, max_depth=4
[57]#011train-rmse:1.06706#011validation-rmse:4.58951
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 18 pruned nodes, max_depth=2
[58]#011train-rmse:1.04869#011validation-rmse:4.55725
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 4 extra nodes, 14 pruned nodes, max_depth=2
[59]#011train-rmse:1.04649#011validation-rmse:4.56979
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 18 pruned nodes, max_depth=0
[60]#011train-rmse:1.04649#011validation-rmse:4.56979
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 16 pruned nodes, max_depth=4
[61]#011train-rmse:1.02877#011validation-rmse:4.54751
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 18 pruned nodes, max_depth=4
[62]#011train-rmse:1.01628#011validation-rmse:4.53898
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 14 pruned nodes, max_depth=3
[63]#011train-rmse:1.0052#011validation-rmse:4.51775
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 4 pruned nodes, max_depth=5
[64]#011train-rmse:0.998801#011validation-rmse:4.51989
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 2 extra nodes, 10 pruned nodes, max_depth=1
[65]#011train-rmse:0.997339#011validation-rmse:4.52241
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 4 extra nodes, 20 pruned nodes, max_depth=2
[66]#011train-rmse:0.990396#011validation-rmse:4.50933
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 2 extra nodes, 14 pruned nodes, max_depth=1
[67]#011train-rmse:0.989698#011validation-rmse:4.51075
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 12 pruned nodes, max_depth=3
[68]#011train-rmse:0.983383#011validation-rmse:4.50904
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 8 pruned nodes, max_depth=3
[69]#011train-rmse:0.97871#011validation-rmse:4.51974
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 10 pruned nodes, max_depth=3
[70]#011train-rmse:0.972099#011validation-rmse:4.51708
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 14 pruned nodes, max_depth=0
[71]#011train-rmse:0.97215#011validation-rmse:4.5164
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 18 pruned nodes, max_depth=3
[72]#011train-rmse:0.956671#011validation-rmse:4.49632
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 16 pruned nodes, max_depth=0
[73]#011train-rmse:0.956678#011validation-rmse:4.4967
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 18 pruned nodes, max_depth=0
[74]#011train-rmse:0.956679#011validation-rmse:4.49672
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 4 extra nodes, 16 pruned nodes, max_depth=2
[75]#011train-rmse:0.951738#011validation-rmse:4.49589
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 10 pruned nodes, max_depth=5
[76]#011train-rmse:0.940893#011validation-rmse:4.50138
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 6 pruned nodes, max_depth=4
[77]#011train-rmse:0.933318#011validation-rmse:4.49635
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 20 pruned nodes, max_depth=4
[78]#011train-rmse:0.92265#011validation-rmse:4.48439
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 6 pruned nodes, max_depth=5
[79]#011train-rmse:0.907698#011validation-rmse:4.4905
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 14 pruned nodes, max_depth=3
[80]#011train-rmse:0.89772#011validation-rmse:4.47185
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 18 pruned nodes, max_depth=0
[81]#011train-rmse:0.897707#011validation-rmse:4.4717
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 30 pruned nodes, max_depth=0
[82]#011train-rmse:0.89774#011validation-rmse:4.47204
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 4 extra nodes, 18 pruned nodes, max_depth=2
[83]#011train-rmse:0.894762#011validation-rmse:4.4904
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 18 pruned nodes, max_depth=4
[84]#011train-rmse:0.88249#011validation-rmse:4.48807
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 14 pruned nodes, max_depth=5
[85]#011train-rmse:0.868593#011validation-rmse:4.46054
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 14 pruned nodes, max_depth=2
[86]#011train-rmse:0.860016#011validation-rmse:4.45059
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 20 pruned nodes, max_depth=0
[87]#011train-rmse:0.860011#011validation-rmse:4.45036
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 30 pruned nodes, max_depth=0
[88]#011train-rmse:0.860011#011validation-rmse:4.45035
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 22 pruned nodes, max_depth=0
[89]#011train-rmse:0.860025#011validation-rmse:4.45083
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 30 pruned nodes, max_depth=0
[90]#011train-rmse:0.86003#011validation-rmse:4.45092
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 4 extra nodes, 18 pruned nodes, max_depth=2
[91]#011train-rmse:0.854747#011validation-rmse:4.44436
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 26 pruned nodes, max_depth=0
[92]#011train-rmse:0.85472#011validation-rmse:4.44378
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 30 pruned nodes, max_depth=0
[93]#011train-rmse:0.854758#011validation-rmse:4.44451
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 16 pruned nodes, max_depth=0
[94]#011train-rmse:0.854715#011validation-rmse:4.44333
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 26 pruned nodes, max_depth=0
[95]#011train-rmse:0.854718#011validation-rmse:4.4431
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 4 extra nodes, 20 pruned nodes, max_depth=2
[96]#011train-rmse:0.849772#011validation-rmse:4.44354
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 16 pruned nodes, max_depth=0
[97]#011train-rmse:0.849783#011validation-rmse:4.44248
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 20 pruned nodes, max_depth=0
[98]#011train-rmse:0.849778#011validation-rmse:4.44258
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 10 pruned nodes, max_depth=4
[99]#011train-rmse:0.842667#011validation-rmse:4.43031
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 16 pruned nodes, max_depth=0
[100]#011train-rmse:0.842665#011validation-rmse:4.43007
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 16 extra nodes, 8 pruned nodes, max_depth=5
[101]#011train-rmse:0.830601#011validation-rmse:4.43165
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 16 pruned nodes, max_depth=0
[102]#011train-rmse:0.830611#011validation-rmse:4.43188
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 22 pruned nodes, max_depth=0
[103]#011train-rmse:0.830591#011validation-rmse:4.43114
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 24 pruned nodes, max_depth=0
[104]#011train-rmse:0.830591#011validation-rmse:4.43118
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 4 extra nodes, 26 pruned nodes, max_depth=2
[105]#011train-rmse:0.825438#011validation-rmse:4.43114
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 20 pruned nodes, max_depth=0
[106]#011train-rmse:0.825473#011validation-rmse:4.43063
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 16 pruned nodes, max_depth=0
[107]#011train-rmse:0.825423#011validation-rmse:4.43199
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 28 pruned nodes, max_depth=0
[108]#011train-rmse:0.825422#011validation-rmse:4.43163
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 18 pruned nodes, max_depth=0
[109]#011train-rmse:0.825427#011validation-rmse:4.43142
[20:34:22] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 26 pruned nodes, max_depth=0
[110]#011train-rmse:0.82543#011validation-rmse:4.43133
Stopping. Best iteration:
[100]#011train-rmse:0.842665#011validation-rmse:4.43007

Training seconds: 64
Billable seconds: 64
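The end of the log shows early stopping at work: validation-rmse last improved at round 100, and after 10 more rounds without improvement (early_stopping_rounds=10) training stopped at round 110 and reported round 100 as the best iteration. The rule can be sketched as follows for a metric where lower is better; this is a simplified illustration, not XGBoost's own code.

```python
def early_stopping(history, patience):
    """Return (best_index, stop_index) for a metric where lower is better.

    Training stops at the first index that is `patience` rounds past the best one;
    if that never happens, the last index is returned as the stop index.
    """
    best_index, best_value = 0, history[0]
    for i, value in enumerate(history):
        if value < best_value:
            best_index, best_value = i, value
        elif i - best_index >= patience:
            return best_index, i
    return best_index, len(history) - 1


# validation-rmse for rounds 99..110, copied from the training log above
tail = [4.43031, 4.43007, 4.43165, 4.43188, 4.43114, 4.43118,
        4.43114, 4.43063, 4.43199, 4.43163, 4.43142, 4.43133]
best, stop = early_stopping(tail, patience=10)
print("best round:", 99 + best, "stopped at round:", 99 + stop)
```

Starting from round 99 does not change the outcome, because every earlier round in the log has a higher validation-rmse than round 99.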

Step 5: Test the model

Now that we have fit our model to the training data, using the validation data to avoid overfitting, we can test our model. To do this we will make use of SageMaker’s Batch Transform functionality. To start with, we need to build a transformer object from our fitted model.

xgb_transformer = xgb.transformer(instance_count = 1, instance_type = 'ml.m4.xlarge')
Parameter image will be renamed to image_uri in SageMaker Python SDK v2.

Next we ask SageMaker to begin a batch transform job that applies our trained model to the test data we previously stored in S3. We need to tell SageMaker the type of data we are sending to the model, in our case text/csv, so that the model container knows how to parse it. In addition, we set split_type='Line' so that SageMaker splits our data into records at line breaks, which lets it send the data to the model in chunks if the entire dataset is too large to send all at once.

Note that SageMaker runs the batch transform job in the background. Since we need the results of this job before we can continue, we call the wait() method. An added benefit is that we get the job's log output, which lets us know if anything went wrong.

xgb_transformer.transform(test_location, content_type='text/csv', split_type='Line')
xgb_transformer.wait()
..............................
Arguments: serve
[2020-12-01 20:49:04 +0000] [1] [INFO] Starting gunicorn 19.7.1
[2020-12-01 20:49:04 +0000] [1] [INFO] Listening at: http://0.0.0.0:8080 (1)
[2020-12-01 20:49:04 +0000] [1] [INFO] Using worker: gevent
[2020-12-01 20:49:04 +0000] [36] [INFO] Booting worker with pid: 36
[2020-12-01 20:49:04 +0000] [37] [INFO] Booting worker with pid: 37
[2020-12-01:20:49:04:INFO] Model loaded successfully for worker : 36
[2020-12-01 20:49:04 +0000] [38] [INFO] Booting worker with pid: 38
[2020-12-01:20:49:04:INFO] Model loaded successfully for worker : 37
[2020-12-01 20:49:04 +0000] [39] [INFO] Booting worker with pid: 39
[2020-12-01:20:49:04:INFO] Model loaded successfully for worker : 38
[2020-12-01:20:49:04:INFO] Sniff delimiter as ','
[2020-12-01:20:49:04:INFO] Determined delimiter of CSV input is ','
[2020-12-01:20:49:04:INFO] Model loaded successfully for worker : 39
2020-12-01T20:49:04.829:[sagemaker logs]: MaxConcurrentTransforms=4, MaxPayloadInMB=6, BatchStrategy=MULTI_RECORD
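The last log line reports the settings SageMaker used: up to 4 concurrent requests, a payload limit of 6 MB per request, and BatchStrategy=MULTI_RECORD. Because we passed split_type='Line', each line of the test file is one record, and the multi-record strategy allows several records to travel in one request. The following is a rough illustration of that packing, assuming records are grouped greedily up to the payload limit; pack_records is hypothetical and is not SageMaker's actual implementation.

```python
from typing import Iterable, List


def pack_records(lines: Iterable[str], max_payload_bytes: int) -> List[str]:
    """Hypothetical illustration of SplitType='Line' with a multi-record batch strategy.

    Each input line is one record; consecutive records are concatenated into
    payloads whose UTF-8 size never exceeds max_payload_bytes.
    """
    payloads: List[str] = []
    current: List[str] = []
    size = 0
    for line in lines:
        record = line.rstrip("\n") + "\n"  # normalise: exactly one trailing newline
        record_size = len(record.encode("utf-8"))
        if record_size > max_payload_bytes:
            raise ValueError("a single record exceeds the payload limit")
        # Start a new payload if this record would not fit in the current one.
        if current and size + record_size > max_payload_bytes:
            payloads.append("".join(current))
            current, size = [], 0
        current.append(record)
        size += record_size
    if current:
        payloads.append("".join(current))
    return payloads
```

Assuming the test.csv file written earlier is still in data_dir, passing its open file handle and 6 * 1024 * 1024 as the limit would most likely yield a single payload, since the Boston test set is small.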

Now that the batch transform job has finished, its output is stored in S3. Since we wish to analyze the output inside our notebook, we use the notebook's ! shell escape to run the AWS CLI, which copies the output file from its S3 location into our local data directory.

!aws s3 cp --recursive $xgb_transformer.output_path $data_dir
download: s3://sagemaker-us-east-1-816427681933/xgboost-2020-12-01-20-43-57-151/test.csv.out to ../data/boston/test.csv.out
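If the AWS CLI is not available, roughly the same copy can be done from Python with boto3. This is a minimal sketch: parse_s3_uri and download_prefix are hypothetical helpers, and it assumes the notebook's IAM role can list and read the output bucket. It pages through list_objects_v2 and calls download_file for every object under the prefix.

```python
import os
from typing import List, Tuple
from urllib.parse import urlparse


def parse_s3_uri(uri: str) -> Tuple[str, str]:
    """Split 's3://bucket/some/prefix' into ('bucket', 'some/prefix')."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


def download_prefix(uri: str, local_dir: str) -> List[str]:
    """Hypothetical helper: copy every object under an S3 prefix into local_dir."""
    import boto3  # imported here so parse_s3_uri works without boto3 installed

    bucket, prefix = parse_s3_uri(uri)
    if prefix and not prefix.endswith("/"):
        prefix += "/"  # treat the prefix like a directory
    s3 = boto3.client("s3")
    downloaded = []
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            relative = key[len(prefix):]
            if not relative or key.endswith("/"):
                continue  # skip "folder" placeholder objects
            target = os.path.join(local_dir, *relative.split("/"))
            os.makedirs(os.path.dirname(target) or ".", exist_ok=True)
            s3.download_file(bucket, key, target)
            downloaded.append(target)
    return downloaded
```

In this notebook the call would be download_prefix(xgb_transformer.output_path, data_dir).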

To see how well our model works, we can create a simple scatter plot of the predicted values against the actual values. If the model were completely accurate, every point would lie on the line $x=y$. As we can see, our model seems to have done reasonably well, but there is still room for improvement.

Y_pred = pd.read_csv(os.path.join(data_dir, 'test.csv.out'), header=None)
plt.scatter(Y_test, Y_pred)
plt.xlabel("Median Price")
plt.ylabel("Predicted Price")
plt.title("Median Price vs Predicted Price")
[Figure: scatter plot "Median Price vs Predicted Price", Median Price on the x-axis, Predicted Price on the y-axis]
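A scatter plot gives a qualitative picture; a few summary numbers make it easier to compare this model with later attempts. The sketch below computes mean squared error, root mean squared error, mean absolute error and R² with NumPy. regression_metrics is a hypothetical helper, and it assumes Y_test and Y_pred hold the same rows in the same order, which the scatter plot already relies on.

```python
import numpy as np


def regression_metrics(y_true, y_pred) -> dict:
    """Hypothetical helper: common regression error metrics for two aligned sequences."""
    # ravel() turns single-column DataFrames or (n, 1) arrays into flat vectors.
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must contain the same number of values")
    residuals = y_true - y_pred
    mse = float(np.mean(residuals ** 2))
    ss_res = float(np.sum(residuals ** 2))
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
    return {
        "mse": mse,
        "rmse": mse ** 0.5,
        "mae": float(np.mean(np.abs(residuals))),
        # R^2 is undefined when every true value is identical.
        "r2": 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan"),
    }
```

In the notebook, regression_metrics(Y_test, Y_pred) would summarize the plot above; an R² close to 1 corresponds to points hugging the line x=y.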

Optional: Clean up

The default notebook instance on SageMaker doesn’t have a lot of excess disk space available. As you continue to complete and execute notebooks you will eventually fill up this disk space, leading to errors which can be difficult to diagnose. Once you are completely finished using a notebook it is a good idea to remove the files that you created along the way. Of course, you can do this from the terminal or from the notebook hub if you would like. The cell below contains some commands to clean up the created files from within the notebook.

# First we will remove all of the files contained in the data_dir directory
!rm $data_dir/*

# And then we delete the directory itself
!rmdir $data_dir
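The shell commands above assume data_dir contains only regular files: the * glob skips hidden files, and rm without -r will not delete subdirectories, in which case rmdir fails because the directory is not empty. A pure-Python alternative, sketched here as a hypothetical helper remove_data_dir, uses shutil.rmtree to delete the directory and everything in it in one step, so it should be pointed only at a directory you really want gone.

```python
import os
import shutil


def remove_data_dir(path: str) -> bool:
    """Hypothetical helper: delete path and all of its contents.

    Returns True if the directory was removed, False if it did not exist.
    """
    if not os.path.isdir(path):
        return False
    shutil.rmtree(path)  # removes files, hidden files and subdirectories
    return True
```

In this notebook the call would be remove_data_dir(data_dir).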
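# The earlier cells use SageMaker Python SDK v1 APIs (get_image_uri, csv_serializer) that were replaced in v2,
# so pin the v1 SDK. Restart the kernel afterwards so the downgraded package is the one that gets imported.
!pip install sagemaker==1.72.0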
Collecting sagemaker==1.72.0
  Downloading sagemaker-1.72.0.tar.gz (297 kB)
Requirement already satisfied: boto3>=1.14.12 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from sagemaker==1.72.0) (1.16.9)
Requirement already satisfied: numpy>=1.9.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from sagemaker==1.72.0) (1.18.1)
Requirement already satisfied: protobuf>=3.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from sagemaker==1.72.0) (3.11.4)
Requirement already satisfied: scipy>=0.19.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from sagemaker==1.72.0) (1.4.1)
Requirement already satisfied: protobuf3-to-dict>=0.1.5 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from sagemaker==1.72.0) (0.1.5)
Collecting smdebug-rulesconfig==0.1.4
  Downloading smdebug_rulesconfig-0.1.4-py2.py3-none-any.whl (10 kB)
Requirement already satisfied: importlib-metadata>=1.4.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from sagemaker==1.72.0) (2.0.0)
Requirement already satisfied: packaging>=20.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from sagemaker==1.72.0) (20.1)
Requirement already satisfied: botocore<1.20.0,>=1.19.9 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from boto3>=1.14.12->sagemaker==1.72.0) (1.19.9)
Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from boto3>=1.14.12->sagemaker==1.72.0) (0.10.0)
Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from boto3>=1.14.12->sagemaker==1.72.0) (0.3.3)
Requirement already satisfied: six>=1.9 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from protobuf>=3.1->sagemaker==1.72.0) (1.14.0)
Requirement already satisfied: setuptools in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from protobuf>=3.1->sagemaker==1.72.0) (45.2.0.post20200210)
Requirement already satisfied: zipp>=0.5 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from importlib-metadata>=1.4.0->sagemaker==1.72.0) (2.2.0)
Requirement already satisfied: pyparsing>=2.0.2 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from packaging>=20.0->sagemaker==1.72.0) (2.4.6)
Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from botocore<1.20.0,>=1.19.9->boto3>=1.14.12->sagemaker==1.72.0) (2.8.1)
Requirement already satisfied: urllib3<1.26,>=1.25.4; python_version != "3.4" in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from botocore<1.20.0,>=1.19.9->boto3>=1.14.12->sagemaker==1.72.0) (1.25.10)
Building wheels for collected packages: sagemaker
  Building wheel for sagemaker (setup.py) ... done
  Created wheel for sagemaker: filename=sagemaker-1.72.0-py2.py3-none-any.whl size=386358 sha256=2f3fd21b8bfb95af0820a92495a3f7224918c14ff6a797d3b02ab416805d8060
  Stored in directory: /home/ec2-user/.cache/pip/wheels/c3/58/70/85faf4437568bfaa4c419937569ba1fe54d44c5db42406bbd7
Successfully built sagemaker
Installing collected packages: smdebug-rulesconfig, sagemaker
  Attempting uninstall: smdebug-rulesconfig
    Found existing installation: smdebug-rulesconfig 0.1.5
    Uninstalling smdebug-rulesconfig-0.1.5:
      Successfully uninstalled smdebug-rulesconfig-0.1.5
  Attempting uninstall: sagemaker
    Found existing installation: sagemaker 2.16.3.post0
    Uninstalling sagemaker-2.16.3.post0:
      Successfully uninstalled sagemaker-2.16.3.post0
Successfully installed sagemaker-1.72.0 smdebug-rulesconfig-0.1.4
WARNING: You are using pip version 20.0.2; however, version 20.3 is available.
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.
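Because Python keeps already-imported modules in memory, the downgrade only takes effect after the kernel is restarted. A small check such as the following can confirm which SDK version is actually loaded before re-running the earlier cells. is_v1_sdk is a hypothetical helper that looks only at the major version number.

```python
def is_v1_sdk(version: str) -> bool:
    """Hypothetical helper: True if a version string such as '1.72.0' has major version 1."""
    major = version.split(".", 1)[0]
    return major.isdigit() and int(major) == 1
```

After restarting, import sagemaker and check is_v1_sdk(sagemaker.__version__); a value of False (for example with the 2.16.3.post0 version that was uninstalled above) means the old module is still loaded.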