Amazon SageMaker has redesigned its Python SDK to provide a unified object-oriented interface that makes it straightforward to interact with SageMaker services. The new SDK is designed with a tiered user experience in mind, where the new lower-level SDK (SageMaker Core) provides access to the full breadth of SageMaker features and configurations, allowing for greater flexibility and control for ML engineers. The higher-level abstracted layer is designed for data scientists with limited AWS expertise, offering a simplified interface that hides complex infrastructure details.
In this two-part series, we introduce the abstracted layer of the SageMaker Python SDK that allows you to train and deploy machine learning (ML) models by using the new ModelTrainer and the improved ModelBuilder classes.
In this post, we focus on the ModelTrainer class for simplifying the training experience. The ModelTrainer class provides significant improvements over the current Estimator class, which are discussed in detail in this post. We show you how to use the ModelTrainer class to train your ML models, which includes executing distributed training using a custom script or container. In Part 2, we show you how to build a model and deploy to a SageMaker endpoint using the improved ModelBuilder class.
Benefits of the ModelTrainer class
The new ModelTrainer class has been designed to address usability challenges associated with the Estimator class. Moving forward, ModelTrainer will be the preferred approach for model training, bringing significant enhancements that greatly improve the user experience. This evolution marks a step towards achieving a best-in-class developer experience for model training. The following are the key benefits:
- Improved intuitiveness – The ModelTrainer class reduces complexity by consolidating configurations into just a few core parameters. This streamlining minimizes cognitive overload, allowing users to focus on model training rather than configuration intricacies. Additionally, it employs intuitive config classes for straightforward platform interactions.
- Simplified script mode and bring your own container (BYOC) – Transitioning from local development to cloud training is now seamless. The ModelTrainer automatically maps source code, data paths, and parameter specifications to the remote execution environment, eliminating the need for special handshakes or complex setup processes.
- Simplified distributed training – The ModelTrainer class provides enhanced flexibility for users to specify custom commands and distributed training strategies, allowing you to directly provide the exact command you want to run in your container through the command parameter in the SourceCode class. This approach decouples distributed training strategies from the training toolkit and framework-specific estimators.
- Improved hyperparameter contracts – The ModelTrainer class passes the training job’s hyperparameters as a single environment variable, allowing you to load the hyperparameters using a single SM_HPS variable.
To further explain each of these benefits, we demonstrate with examples in the following sections, and finally show you how to set up and run distributed training for the Meta Llama 3.1 8B model using the new ModelTrainer class.
Launch a training job using the ModelTrainer class
The ModelTrainer class simplifies the experience by letting you customize the training job, including providing a custom script, directly providing a command to run the training job, supporting local mode, and much more. However, you can spin up a SageMaker training job in script mode by providing minimal parameters: the SourceCode and the training image URI.
The following example illustrates how you can launch a training job with your own custom script by providing just the script and the training image URI (in this case, PyTorch), and an optional requirements file. Additional parameters such as the instance type and instance size are automatically set by the SDK to preset defaults, and parameters such as the AWS Identity and Access Management (IAM) role and SageMaker session are automatically detected from the current session and user’s credentials. Admins and users can also overwrite the defaults using the SDK defaults configuration file. For the detailed list of pre-set values, refer to the SDK documentation.
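A minimal sketch of such a job might look like the following. It assumes the sagemaker.modules import paths used by recent 2.x releases of the SDK; the folder, script, job, and channel names are hypothetical placeholders, and the image URI and S3 path are supplied by the caller rather than taken from this post.

```python
def launch_script_mode_job(training_image: str, train_s3_uri: str):
    """Launch a script-mode training job from a script and an image URI."""
    # Imported inside the function so the sketch can be loaded without the SDK;
    # these import paths are assumed from recent 2.x SDK releases.
    from sagemaker.modules.configs import InputData, SourceCode
    from sagemaker.modules.train import ModelTrainer

    # Hypothetical local folder holding the script and an optional requirements file.
    source_code = SourceCode(
        source_dir="basic-script-mode",
        requirements="requirements.txt",
        entry_script="custom_script.py",
    )

    # Instance settings, IAM role, and session are left to the SDK defaults.
    model_trainer = ModelTrainer(
        training_image=training_image,
        source_code=source_code,
        base_job_name="script-mode",
    )

    # Hypothetical input channel pointing at training data in Amazon S3.
    train_data = InputData(channel_name="train", data_source=train_s3_uri)
    model_trainer.train(input_data_config=[train_data])
    return model_trainer
```

Calling the function with a PyTorch training image URI for your Region and an S3 prefix starts the job; nothing runs until it is called.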
With purpose-built configurations, you can now reuse these objects to create multiple training jobs with different hyperparameters, for example, without having to re-define all the parameters.
Run the job locally for experimentation
To run the preceding training job locally, set the training_mode parameter as shown in the following code:
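As a hedged sketch, the local variant could be written as follows. The Mode import path is assumed from recent 2.x SDK releases, and the folder, script, and job names are hypothetical.

```python
def run_job_locally(training_image: str):
    """Run the same script-mode job in a local container instead of remotely."""
    # Import paths assumed from recent 2.x SDK releases; kept inside the
    # function so the sketch can be loaded without the SDK installed.
    from sagemaker.modules.configs import SourceCode
    from sagemaker.modules.train import ModelTrainer
    from sagemaker.modules.train.model_trainer import Mode

    # Hypothetical local folder and script names.
    source_code = SourceCode(
        source_dir="basic-script-mode",
        entry_script="custom_script.py",
    )

    model_trainer = ModelTrainer(
        training_image=training_image,
        source_code=source_code,
        base_job_name="script-mode-local",
        # The only change from a remote job: select the local container mode.
        training_mode=Mode.LOCAL_CONTAINER,
    )
    model_trainer.train()
    return model_trainer
```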
The training job runs locally because training_mode is set to Mode.LOCAL_CONTAINER. If not explicitly set, the ModelTrainer runs a remote SageMaker training job by default. This behavior can also be enforced by changing the value to Mode.SAGEMAKER_TRAINING_JOB. For a full list of the available configs, including compute and networking, refer to the SDK documentation.
Read hyperparameters in your custom script
The ModelTrainer supports multiple ways to read the hyperparameters that are passed to a training job. In addition to the existing support to read the hyperparameters as command line arguments in your custom script, ModelTrainer also supports reading the hyperparameters as individual environment variables, named SM_HP_<hyperparameter-key>, or as a single environment variable dictionary, SM_HPS.
Suppose the following hyperparameters are passed to the training job:
You have the following options:
- Option 1 – Load the hyperparameters into a single JSON dictionary using the SM_HPS environment variable in your custom script:
- Option 2 – Read the hyperparameters as individual environment variables, prefixed by SM_HP_, as shown in the following code (you need to explicitly specify the correct input type for these variables):
- Option 3 – Read the hyperparameters as command line arguments using argparse's parse_args:
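The three options can be sketched together using only the Python standard library. The hyperparameter names and values here (learning_rate, epochs) are hypothetical, and the sketch assumes each individual variable is named with the uppercased key after the SM_HP_ prefix. Each reader accepts an explicit environment or argument list so it can be tried outside a training container.

```python
import argparse
import json
import os


def read_from_sm_hps(environ=os.environ) -> dict:
    """Option 1: parse the single SM_HPS JSON dictionary."""
    return json.loads(environ.get("SM_HPS", "{}"))


def read_from_individual_vars(environ=os.environ) -> dict:
    """Option 2: read SM_HP_<KEY> variables.

    Environment values are strings, so each one is cast explicitly.
    The uppercased key names are an assumption of this sketch.
    """
    return {
        "learning_rate": float(environ.get("SM_HP_LEARNING_RATE", "3e-5")),
        "epochs": int(environ.get("SM_HP_EPOCHS", "1")),
    }


def read_from_cli(argv=None) -> dict:
    """Option 3: parse --key value command line arguments with argparse."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--learning_rate", type=float, default=3e-5)
    parser.add_argument("--epochs", type=int, default=1)
    # parse_known_args tolerates extra arguments the script does not declare.
    args, _unknown = parser.parse_known_args(argv)
    return vars(args)


if __name__ == "__main__":
    # Simulated environment standing in for what the training container would set.
    fake_env = {
        "SM_HPS": json.dumps({"learning_rate": 3e-5, "epochs": 2}),
        "SM_HP_LEARNING_RATE": "3e-05",
        "SM_HP_EPOCHS": "2",
    }
    print(read_from_sm_hps(fake_env))
    print(read_from_individual_vars(fake_env))
    print(read_from_cli(["--learning_rate", "3e-05", "--epochs", "2"]))
```

Option 1 keeps the types encoded in the JSON, whereas Options 2 and 3 rely on the script to declare the type of each value.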
Run distributed training jobs
SageMaker supports distributed training for deep learning tasks such as natural language processing and computer vision, so you can run secure and scalable data parallel and model parallel jobs. This is usually achieved by providing the right set of parameters when using an Estimator. For example, to use torchrun, you would define the distribution parameter in the PyTorch Estimator and set it to "torch_distributed": {"enabled": True}.
The ModelTrainer class provides enhanced flexibility for users to specify custom commands directly through the command parameter in the SourceCode class, and supports torchrun, torchrun smp, and the MPI strategies. This capability is particularly useful when you need to launch a job with a custom launcher command that is not supported by the training toolkit.
In the following example, we show how to fine-tune the latest Meta Llama 3.1 8B model using the default launcher script using Torchrun on a custom dataset that’s preprocessed and saved in an Amazon Simple Storage Service (Amazon S3) location:
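A hedged sketch of such a job follows. The import paths and the distributed keyword are assumptions based on recent 2.x SDK releases (the keyword name may differ in your installed version), and the folder, script, instance type, channel name, and hyperparameters are hypothetical placeholders rather than the values used for the Meta Llama 3.1 8B run.

```python
def fine_tune_with_torchrun(training_image: str, dataset_s3_uri: str):
    """Start a distributed fine-tuning job that uses the default Torchrun launcher."""
    # Import paths assumed from recent 2.x SDK releases; kept inside the
    # function so the sketch can be loaded without the SDK installed.
    from sagemaker.modules.configs import Compute, InputData, SourceCode
    from sagemaker.modules.distributed import Torchrun
    from sagemaker.modules.train import ModelTrainer

    # Hypothetical folder and entry script for the fine-tuning code.
    source_code = SourceCode(
        source_dir="distributed-training-scripts",
        requirements="requirements.txt",
        entry_script="fine_tune.py",
    )

    # Hypothetical compute choice; pick an instance type with enough GPU memory.
    compute = Compute(
        instance_count=1,
        instance_type="ml.g5.48xlarge",
        volume_size_in_gb=96,
    )

    model_trainer = ModelTrainer(
        training_image=training_image,
        source_code=source_code,
        compute=compute,
        # Assumed keyword name; check the ModelTrainer signature in your SDK version.
        distributed=Torchrun(),
        hyperparameters={"epochs": 1},  # hypothetical
    )

    # Preprocessed dataset stored under an S3 prefix supplied by the caller.
    input_data = InputData(channel_name="dataset", data_source=dataset_s3_uri)
    model_trainer.train(input_data_config=[input_data], wait=False)
    return model_trainer
```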
If you want to customize your torchrun launcher script, you can also directly provide the commands using the command parameter:
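One way to sketch this is to build the launcher command as a string and hand it to SourceCode. The helper names, folder, script name, and process counts below are hypothetical; the torchrun flags --nnodes and --nproc_per_node are standard torchrun options, and the SourceCode import path is assumed from recent 2.x SDK releases.

```python
def build_torchrun_command(script: str, nproc_per_node: int, nnodes: int = 1) -> str:
    """Compose a torchrun launcher command for the given training script."""
    return f"torchrun --nnodes {nnodes} --nproc_per_node {nproc_per_node} {script}"


def make_source_code(command: str):
    """Wrap a custom launcher command in a SourceCode config."""
    # Import path assumed from recent 2.x SDK releases; kept inside the
    # function so the command helper can be used without the SDK installed.
    from sagemaker.modules.configs import SourceCode

    return SourceCode(
        source_dir="distributed-training-scripts",  # hypothetical folder
        requirements="requirements.txt",
        command=command,  # runs in the container in place of an entry script
    )


if __name__ == "__main__":
    # Hypothetical script name and GPU count.
    print(build_torchrun_command("fine_tune.py", nproc_per_node=8))
```

The resulting SourceCode object is passed to ModelTrainer through the source_code parameter, in the same way as a config that names an entry script.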
For more examples and end-to-end ML workflows using the SageMaker ModelTrainer, refer to the GitHub repo.
Conclusion
The newly launched SageMaker ModelTrainer class simplifies the user experience by reducing the number of parameters, introducing intuitive configurations, and supporting complex setups like bringing your own container and running distributed training. Data scientists can also seamlessly transition from local training to remote training and training on multiple nodes using the ModelTrainer.
We encourage you to try out the ModelTrainer class by referring to the SDK documentation and sample notebooks on the GitHub repo. The ModelTrainer class is available from the SageMaker SDK v2.x onwards, at no additional charge. In Part 2 of this series, we show you how to build a model and deploy to a SageMaker endpoint using the improved ModelBuilder class.
About the Authors
Durga Sury is a Senior Solutions Architect on the Amazon SageMaker team. Over the past 5 years, she has worked with multiple enterprise customers to set up a secure, scalable AI/ML platform built on SageMaker.
Shweta Singh is a Senior Product Manager in the Amazon SageMaker Machine Learning (ML) platform team at AWS, leading the SageMaker Python SDK. She has worked in several product roles in Amazon for over 5 years. She has a Bachelor of Science degree in Computer Engineering and a Master of Science in Financial Engineering, both from New York University.