Set up cross-account Amazon S3 access for Amazon SageMaker notebooks in VPC-only mode using Amazon S3 Access Points

Set up cross-account Amazon S3 access for Amazon SageMaker notebooks in VPC-only mode using Amazon S3 Access Points

Advancements in artificial intelligence (AI) and machine learning (ML) are revolutionizing the financial industry for use cases such as fraud detection, creditworthiness assessment, and trading strategy optimization. To develop models for such use cases, data scientists need access to various datasets like credit decision engines, customer transactions, risk appetite, and stress testing. Managing appropriate access control for these datasets among the data scientists working on them is crucial to meet stringent compliance and regulatory requirements. Typically, these datasets are aggregated in a centralized Amazon Simple Storage Service (Amazon S3) location from various business applications and enterprise systems. Data scientists across business units working on model development using Amazon SageMaker are granted access to relevant data, which can lead to the requirement of managing prefix-level access controls. As the number of use cases and datasets grows, the bucket policy statements required to manage cross-account access per application become too numerous and complex for a single bucket policy to accommodate.

Amazon S3 Access Points simplify managing and securing data access at scale for applications using shared datasets on Amazon S3. You can create unique hostnames using access points to enforce distinct and secure permissions and network controls for any request made through the access point.

S3 Access Points simplify the management of access permissions specific to each application accessing a shared dataset. They enable secure, high-speed data copy between same-Region access points using AWS internal networks and VPCs. S3 Access Points can restrict access to VPCs, enabling you to firewall data within private networks, test new access control policies without impacting existing access points, and configure VPC endpoint policies to restrict access to S3 buckets owned by specific account IDs.

This post walks through the steps involved in configuring S3 Access Points to enable cross-account access from a SageMaker notebook instance.

Solution overview

For our use case, we have two accounts in an organization: Account A (111111111111), which is used by data scientists to develop models using a SageMaker notebook instance, and Account B (222222222222), which contains the required datasets in the S3 bucket test-bucket-1. The following diagram illustrates the solution architecture.

To implement the solution, complete the following high-level steps:

  1. Configure Account A, including the VPC, subnet, security group, VPC gateway endpoint, and SageMaker notebook.
  2. Configure Account B, including S3 bucket, access point, and bucket policy.
  3. Configure AWS Identity and Access Management (IAM) permissions and policies in Account A.

You should repeat these steps for each SageMaker account that needs access to the shared dataset from Account B.

The names for each resource mentioned in this post are examples; you can replace them with other names as per your use case.

Configure Account A

Complete the following steps to configure Account A:

  1. Create a VPC called DemoVPC.
  2. Create a subnet called DemoSubnet in the VPC DemoVPC.
  3. Create a security group called DemoSG.
  4. Create a VPC S3 gateway endpoint called DemoS3GatewayEndpoint.
  5. Create the SageMaker execution role.
  6. Create a notebook instance called DemoNotebookInstance, following the security guidelines outlined in How to configure security in Amazon SageMaker (see the example sketch below).
    1. Specify the SageMaker execution role you created.
    2. For the notebook network settings, specify the VPC, subnet, and security group you created.
    3. Make sure that Direct Internet access is disabled.

You assign permissions to the role in subsequent steps after you create the required dependencies.
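
For reference, the following is a minimal sketch of step 6 using the AWS SDK for Python (Boto3); the instance type, subnet ID, security group ID, and role ARN are placeholder values that you would replace with the resources you created.

import boto3

sagemaker = boto3.client("sagemaker", region_name="us-east-1")

# Placeholder IDs and ARN: substitute the values from the resources created above.
sagemaker.create_notebook_instance(
    NotebookInstanceName="DemoNotebookInstance",
    InstanceType="ml.t3.medium",                      # example instance type
    RoleArn="arn:aws:iam::111111111111:role/demo",    # SageMaker execution role
    SubnetId="subnet-0123456789abcdef0",              # DemoSubnet
    SecurityGroupIds=["sg-0123456789abcdef0"],        # DemoSG
    DirectInternetAccess="Disabled",                  # keep the notebook in VPC-only mode
)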

Configure Account B

To configure Account B, complete the following steps:

  1. In Account B, create an S3 bucket called test-bucket-1 following Amazon S3 security guidance.
  2. Upload your file to the S3 bucket.
  3. Create an access point called test-ap-1 in Account B.
    1. Don’t change or edit any Block Public Access settings for this access point (all public access should be blocked).
  4. Attach the following policy to your access point:
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {
                "AWS": "arn:aws:iam::111111111111:role/demo"
            },
            "Action": ["s3:GetObject", "s3:GetObjectVersion", "s3:PutObject", "s3:PutObjectAcl"],
            "Resource": [
                "arn:aws:s3:us-east-1:222222222222:accesspoint/test-ap-1",
                "arn:aws:s3:us-east-1:222222222222:accesspoint/test-ap-1/object/*"
            ]
        }
    ]
}

The actions defined in the preceding code are sample actions for demonstration purposes. You can define the actions as per your requirements or use case.

  5. Add the following bucket policy permissions to allow access through the access point:
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {
                "AWS": "arn:aws:iam::111111111111:role/demo"
            },
            "Action": ["s3:GetObject", "s3:ListBucket"],
            "Resource": ["arn:aws:s3:::test-bucket-1", "arn:aws:s3:::test-bucket-1/*"],
            "Condition": {
                "StringEquals": {
                    "s3:DataAccessPointAccount": "222222222222"
                }
            }
        }
    ]
}

The preceding actions are examples. You can define the actions as per your requirements.

Configure IAM permissions and policies

Complete the following steps in Account A:

  1. Confirm that the SageMaker execution role has the AmazonSageMakerFullAccess policy attached, along with a custom IAM inline policy that looks like the following code:
{
    "Sid": "VisualEditor2",
    "Effect": "Allow",
    "Action": ["s3:GetObject", "s3:GetObjectVersion", "s3:PutObject", "s3:PutObjectAcl"],
    "Resource": [
        "arn:aws:s3:us-east-1:222222222222:accesspoint/test-ap-1",
        "arn:aws:s3:us-east-1:222222222222:accesspoint/test-ap-1/object/*",
        "arn:aws:s3:::test-bucket-1",
        "arn:aws:s3:::test-bucket-1/*"
    ]
}

The actions in the policy code are sample actions for demonstration purposes.

  2. Go to the DemoS3GatewayEndpoint endpoint you created and add the following permissions:
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "AllowCrossAccountAccessThroughAccessPoint",
            "Effect": "Allow",
            "Principal": "*",
            "Action": [
                "s3:Get*",
                "s3:List*",
                "s3:Put*"
            ],
            "Resource": [
                "arn:aws:s3:us-east-1:222222222222:accesspoint/test-ap-1",
                "arn:aws:s3:us-east-1:222222222222:accesspoint/test-ap-1/object/*",
                "arn:aws:s3:::test-bucket-1",
                "arn:aws:s3:::test-bucket-1/*"
            ]
        }
    ]
}
  3. To get the prefix list ID for Amazon S3 in your Region, run the AWS Command Line Interface (AWS CLI) describe-prefix-lists command:
aws ec2 describe-prefix-lists
  4. In Account A, go to the security group DemoSG for the target SageMaker notebook instance.
  5. Under Outbound rules, create an outbound rule with All traffic or All TCP, and then specify the destination as the prefix list ID you retrieved (see the sketch after this list).
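
If you prefer to script steps 3–5, the following Boto3 sketch performs the same lookup and adds the outbound rule; the security group ID is a placeholder, and the example assumes the us-east-1 Region.

import boto3

ec2 = boto3.client("ec2", region_name="us-east-1")

# Look up the managed prefix list for Amazon S3 in this Region.
response = ec2.describe_prefix_lists(
    Filters=[{"Name": "prefix-list-name", "Values": ["com.amazonaws.us-east-1.s3"]}]
)
s3_prefix_list_id = response["PrefixLists"][0]["PrefixListId"]

# Add an outbound rule on DemoSG (placeholder ID) allowing TCP traffic to the S3 prefix list.
ec2.authorize_security_group_egress(
    GroupId="sg-0123456789abcdef0",
    IpPermissions=[{
        "IpProtocol": "tcp",
        "FromPort": 0,
        "ToPort": 65535,
        "PrefixListIds": [{"PrefixListId": s3_prefix_list_id}],
    }],
)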

This completes the setup in both accounts.

Test the solution

To validate the solution, go to the SageMaker notebook instance terminal and enter the following commands to list and retrieve objects through the access point:

  • To list the objects successfully through S3 access point test-ap-1:
aws s3 ls arn:aws:s3:us-east-1:222222222222:accesspoint/test-ap-1

  • To get the objects successfully through S3 access point test-ap-1:
aws s3api get-object --bucket arn:aws:s3:us-east-1:222222222222:accesspoint/test-ap-1 --key sample2.csv test2.csv
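
You can run the same validation from a notebook cell with Boto3; the object key sample2.csv is a placeholder for a file you uploaded earlier.

import boto3

s3 = boto3.client("s3", region_name="us-east-1")
access_point_arn = "arn:aws:s3:us-east-1:222222222222:accesspoint/test-ap-1"

# The access point ARN is accepted wherever a bucket name is expected.
response = s3.list_objects_v2(Bucket=access_point_arn)
print([obj["Key"] for obj in response.get("Contents", [])])

# Download an object through the access point; sample2.csv is a placeholder key.
s3.download_file(access_point_arn, "sample2.csv", "test2.csv")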

Clean up

When you’re done testing, delete any S3 access points and S3 buckets. Also, delete any SageMaker notebook instances to stop incurring charges.

Conclusion

In this post, we showed how S3 Access Points enable cross-account access to large, shared datasets from SageMaker notebook instances, avoiding the size constraints of bucket policies while providing at-scale access management on shared datasets.

To learn more, refer to Easily Manage Shared Data Sets with Amazon S3 Access Points.


About the authors

Kiran Khambete is a Senior Technical Account Manager at Amazon Web Services (AWS). As a TAM, Kiran plays the role of technical expert and strategic guide, helping enterprise customers achieve their business goals.

Ankit Soni, with 14 years of total experience, is a Principal Engineer at NatWest Group, where he has served as a Cloud Infrastructure Architect for the past six years.

Kesaraju Sai Sandeep is a Cloud Engineer specializing in Big Data Services at AWS.

Read More

AI Decoded: Demystifying Large Language Models, the Brains Behind Chatbots

AI Decoded: Demystifying Large Language Models, the Brains Behind Chatbots

Editor’s note: This post is part of our AI Decoded series, which aims to demystify AI by making the technology more accessible, while showcasing new hardware, software, tools and accelerations for RTX PC and workstation users.

If AI is having its iPhone moment, then chatbots are one of its first popular apps.

They’re made possible thanks to large language models, deep learning algorithms pretrained on massive datasets — as expansive as the internet itself — that can recognize, summarize, translate, predict and generate text and other forms of content. They can run locally on PCs and workstations powered by NVIDIA GeForce and RTX GPUs.

LLMs excel at summarizing large volumes of text, classifying and mining data for insights, and generating new text in a user-specified style, tone or format. They can facilitate communication in any language, even beyond ones spoken by humans, such as computer code or protein and genetic sequences.

While the first LLMs dealt solely with text, later iterations were trained on other types of data. These multimodal LLMs can recognize and generate images, audio, videos and other content forms.

Chatbots like ChatGPT were among the first to bring LLMs to a consumer audience, with a familiar interface built to converse with and respond to natural-language prompts. LLMs have since been used to help developers write code and scientists to drive drug discovery and vaccine development.

But the AI models that power those functions are computationally intensive. Combining advanced optimization techniques and algorithms like quantization with RTX GPUs, which are purpose-built for AI, helps make LLMs compact enough and PCs powerful enough to run locally — no internet connection required. And a new breed of lightweight LLMs like Mistral — one of the LLMs powering Chat with RTX — sets the stage for state-of-the-art performance with lower power and storage demands.

Why Do LLMs Matter?

LLMs can be adapted for a wide range of use cases, industries and workflows. This versatility, combined with their high-speed performance, offers performance and efficiency gains across virtually all language-based tasks.

DeepL, running on NVIDIA GPUs in the cloud, uses advanced AI to provide accurate text translations.

LLMs are widely used in language translation apps such as DeepL, which uses AI and machine learning to provide accurate outputs.

Medical researchers are training LLMs on textbooks and other medical data to enhance patient care. Retailers are leveraging LLM-powered chatbots to deliver stellar customer support experiences. Financial analysts are tapping LLMs to transcribe and summarize earning calls and other important meetings. And that’s just the tip of the iceberg.

Chatbots — like Chat with RTX — and writing assistants built atop LLMs are making their mark on every facet of knowledge work, from content marketing and copywriting to legal operations. Coding assistants were among the first LLM-powered applications to point toward the AI-assisted future of software development. Now, projects like ChatDev are combining LLMs with AI agents — smart bots that act autonomously to help answer questions or perform digital tasks — to spin up an on-demand, virtual software company. Just tell the system what kind of app is needed and watch it get to work.

Learn more about LLM agents on the NVIDIA developer blog.

Easy as Striking Up a Conversation 

Many people’s first encounter with generative AI came by way of a chatbot such as ChatGPT, which simplifies the use of LLMs through natural language, making user action as simple as telling the model what to do.

LLM-powered chatbots can help generate a draft of marketing copy, offer ideas for a vacation, craft an email to customer service and even spin up original poetry.

Advances in image generation and multimodal LLMs have extended the chatbot’s realm to include analyzing and generating imagery — all while maintaining the wonderfully simple user experience. Just describe an image to the bot or upload a photo and ask the system to analyze it. It’s chatting, but now with visual aids.

For more on how these bots are designed, check out the on-demand webinar on Building Intelligent AI Chatbots Using RAG.

Future advancements will help LLMs expand their capacity for logic, reasoning, math and more, giving them the ability to break complex requests into smaller subtasks.

Progress is also being made on AI agents, applications capable of taking a complex prompt, breaking it into smaller ones, and engaging autonomously with LLMs and other AI systems to complete them. ChatDev is an example of an AI agent framework, but agents aren’t limited to technical tasks.

For example, users could ask a personal AI travel agent to book a family vacation abroad. The agent would break that task into subtasks — itinerary planning, booking travel and lodging, creating packing lists, finding a dog walker — and independently execute them in order.

Unlock Personal Data With RAG

As powerful as LLMs and chatbots are for general use, they can become even more helpful when combined with an individual user’s data. By doing so, they can help analyze email inboxes to uncover trends, comb through dense user manuals to find the answer to a technical question about some hardware, or summarize years of bank and credit card statements.

Retrieval-augmented generation, or RAG, is one of the easiest and most effective ways to hone LLMs for a particular dataset.

An example of RAG on a PC.

RAG enhances the accuracy and reliability of generative AI models with facts fetched from external sources. By connecting an LLM with practically any external resource, RAG lets users chat with data repositories while also giving the LLM the ability to cite its sources. The user experience is as simple as pointing the chatbot toward a file or directory.

For example, a standard LLM will have general knowledge about content strategy best practices, marketing tactics and basic insights into a particular industry or customer base. But connecting it via RAG to marketing assets supporting a product launch would allow it to analyze the content and help plan a tailored strategy.

RAG works with any LLM, as long as the application supports it. NVIDIA’s Chat with RTX tech demo is an example of RAG connecting an LLM to a personal dataset. It runs locally on systems with a GeForce RTX or NVIDIA RTX professional GPU.

To learn more about RAG and how it compares to fine-tuning an LLM, read the tech blog, RAG 101: Retrieval-Augmented Generation Questions Answered.

Experience the Speed and Privacy of Chat with RTX

Chat with RTX is a local, personalized chatbot demo that’s easy to use and free to download. It’s built with RAG functionality, TensorRT-LLM and RTX acceleration. It supports multiple open-source LLMs, including Meta’s Llama 2 and Mistral’s Mistral. Support for Google’s Gemma is coming in a future update.

Chat with RTX connects users to their personal data through RAG.

Users can easily connect local files on a PC to a supported LLM simply by dropping files into a folder and pointing the demo to that location. Doing so enables it to answer queries with quick, contextually relevant answers.

Since Chat with RTX runs locally on Windows with GeForce RTX PCs and NVIDIA RTX workstations, results are fast — and the user’s data stays on the device. Rather than relying on cloud-based services, Chat with RTX lets users process sensitive data on a local PC without the need to share it with a third party or have an internet connection.

To learn more about how AI is shaping the future, tune in to NVIDIA GTC, a global AI developer conference running March 18-21 in San Jose, Calif., and online.

Read More

Currents of Change: ITIF President Daniel Castro on Energy-Efficient AI and Climate Change

Currents of Change: ITIF President Daniel Castro on Energy-Efficient AI and Climate Change

AI-driven change is in the air, as are concerns about the technology’s environmental impact. In this episode of NVIDIA’s AI Podcast, Daniel Castro, vice president of the Information Technology and Innovation Foundation and director of its Center for Data Innovation, speaks with host Noah Kravitz about the motivation behind his AI energy use report, which addresses misconceptions about the technology’s energy consumption. Castro also touches on the need for policies and frameworks that encourage the development of energy-efficient technology. Tune in to discover the crucial role of GPU acceleration in enhancing sustainability and how AI can help address climate change challenges.

Register for NVIDIA GTC, a global AI developer conference running March 18-21 in San Jose, Calif., to explore sessions on energy-efficient computing and using AI to combat climate change.

You Might Also Like…

Overjet on Bringing AI to Dentistry – Ep. 179

Dentists get a bad rap. Dentists also get more people out of more aggravating pain than just about anyone, which is why the more technology dentists have, the better. Overjet, a member of the NVIDIA Inception program for startups, is moving fast to bring AI to dentists’ offices.

DigitalPath’s Ethan Higgins on Using AI to Fight Wildfires – Ep. 211

DigitalPath is igniting change in the Golden State — using computer vision, generative adversarial networks and a network of thousands of cameras to detect signs of fire in real time.

Anima Anandkumar on Using Generative AI to Tackle Global Challenges – Ep. 204

Anima Anandkumar, Bren Professor at Caltech and senior director of AI research at NVIDIA, speaks to generative AI’s potential to make splashes in the scientific community, from accelerating drug and vaccine research to predicting extreme weather events like hurricanes or heat waves.

Doing the Best They Can: EverestLabs Ensures Fewer Recyclables Go to Landfills – Ep. 184

All of us recycle. Or, at least, all of us should. Now, AI is joining the effort. JD Ambati, founder and CEO of EverestLabs, developer of RecycleOS, discusses developing the first AI-enabled operating system for recycling.

Show Notes

1:41: Context on and findings from the AI energy use report
10:36: How GPU acceleration has transformed the energy efficiency of AI, particularly in weather and climate forecasting
12:31: Examples of how GPU acceleration has improved the energy efficiency of AI operations
15:51: Castro’s insights on sustainability and AI
20:01: Policies and frameworks to encourage energy-efficient AI
26:43: Castro’s outlook on the interplay among advancing AI technology, energy sustainability and climate change

Subscribe to the AI Podcast

Get the AI Podcast through iTunes, Google Podcasts, Google Play, Amazon Music, Castbox, DoggCatcher, Overcast, PlayerFM, Pocket Casts, Podbay, PodBean, PodCruncher, PodKicker, Soundcloud, Spotify, Stitcher and TuneIn.

Make the AI Podcast better: Have a few minutes to spare? Fill out this listener survey.

 

Read More

Maximizing training throughput using PyTorch FSDP

Maximizing training throughput using PyTorch FSDP

In this blog, we demonstrate the scalability of FSDP with a pre-training exemplar, a 7B model trained for 2T tokens, and share various techniques we used to achieve a rapid training speed of 3,700 tokens/sec/GPU, or 40B tokens/day on 128 A100 GPUs. This translates to a model FLOPS utilization (MFU) and hardware FLOPS utilization (HFU) of 57%. Additionally, we have observed near linear scaling of FSDP to 512 GPUs, implying that training a 7B model on 512 GPUs to 2T tokens using this method would take just under two weeks.

IBM researchers trained a Meta Llama 2 7B architecture to 2T tokens, which we will refer to as LlamaT(est). This model demonstrates comparable model quality to Llama 2 on various academic benchmarks. All of the training code, along with our methodology to achieve this throughput, can be found in this blog. We also share the configuration knobs that work well for the Llama 2 models – 7B, 13B, 34B, and 70B – for A100s and H100s.

In this process, we also propose a new selective activation checkpointing mechanism that applies to FSDP and gives us a 10% boost beyond out-of-the-box FSDP. We have open sourced the training code base and an associated scalable data loader, along with the methodology to achieve this throughput.

One key benefit of a PyTorch native pathway for training is the ability to seamlessly train on multiple hardware backends. For example, the recent end-to-end stack for training that was released by AllenAI through OLMo also leverages PyTorch FSDP for training on AMD and NVIDIA GPUs. There are three main components that we leverage from FSDP to achieve our throughput:

  1. SDPA Flash Attention, which enables fused attention kernels and efficient attention computation (see the sketch after this list)
  2. Overlap in computation and communication allows for better utilization of the GPU
  3. Selective activation checkpointing enables us to trade off between GPU memory and compute
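
As an illustration of the first component (and not the IBM training code itself), the following minimal sketch shows SDPA-based causal attention; on supported GPUs with bf16 inputs, PyTorch dispatches this call to a fused Flash Attention kernel. The tensor shapes are illustrative only.

import torch
import torch.nn.functional as F

# Toy tensors shaped (batch, heads, sequence, head_dim); sizes are illustrative only.
q = torch.randn(2, 32, 4096, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 32, 4096, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 32, 4096, 128, device="cuda", dtype=torch.bfloat16)

# SDPA selects the fastest available backend (Flash Attention on supported hardware)
# for this causal attention computation.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)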

IBM has been working closely with Team PyTorch at Meta on PyTorch FSDP for nearly two years: introducing the rate limiter for achieving better throughput on Ethernet interconnects, distributed checkpointing to improve the checkpoint times by an order of magnitude, and implementing the early version of checkpointing for the hybrid sharding mode of FSDP. Late last year, we used FSDP to train a model end-to-end.

Training Details

The 7B model is trained on 128 A100 GPUs with 400Gbps network connectivity and GPU direct RDMA. We use SDPA FlashAttention v2 for attention computation, and for this model we turned off activation checkpointing, which limits the batch size but provides the highest throughput; the batch size is 1 million tokens per batch for 128 GPUs, and throughput improves by about 10% compared to running with activation checkpointing. With these parameters, we have an almost full overlap in computation and communication. We use the AdamW optimizer in 32-bit with beta1 of 0.9 and beta2 of 0.95, weight decay of 0.1, and a learning rate ending at 3e-5, with a warmup to a max learning rate of 3e-4 and a cosine schedule reducing to 3e-5 over 2T tokens. The training was performed using mixed precision bf16 on an internal dataset. The training stack uses IBM’s Foundation Model Stack for model architecture and PyTorch nightlies post-2.2 release for FSDP and SDPA. We tried a few different nightlies during the period of Nov 2023 through Feb 2024 and observed an improvement in throughput.
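
The optimizer and learning rate schedule described above can be approximated with stock PyTorch; the sketch below is illustrative rather than IBM's training code, and the warmup/total step counts and the tiny stand-in model are placeholder assumptions.

import math
import torch

model = torch.nn.Linear(16, 16)            # stand-in for the real 7B model
warmup_steps, total_steps = 2000, 500_000  # placeholders; the real run spans 2T tokens

optimizer = torch.optim.AdamW(
    model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1
)

def lr_lambda(step):
    # Linear warmup to the max learning rate (3e-4) ...
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    # ... then cosine decay down to 3e-5 (one tenth of the max).
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)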

Selective activation checkpointing

We jointly implemented a simple and effective mechanism of selective activation checkpointing (AC). In FSDP, the common practice is to checkpoint each transformer block. A simple extension is to checkpoint every n-th block, which reduces the amount of recomputation while increasing the memory needed. This is quite effective for the 13B model size, increasing the throughput by 10%. For the 7B model size, we did not need activation checkpointing at all. Future versions of FSDP will provide selective activation checkpointing at an operator level, enabling an optimal compute-memory tradeoff. The code for the above is implemented here.
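
A minimal sketch of the every-n-th-block idea, using FSDP's checkpoint-wrapper utilities, might look like the following; it is not the open-sourced implementation, and the TransformerBlock class and the choice of checkpointing one block in three are illustrative assumptions.

from functools import partial
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

# Illustrative transformer block; in practice this is the model's decoder layer class.
class TransformerBlock(nn.Module):
    def __init__(self, dim=256):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
    def forward(self, x):
        return x + self.ff(x)

model = nn.Sequential(*[TransformerBlock() for _ in range(12)])

def make_every_nth_check_fn(n):
    counter = {"blocks_seen": 0}
    def check_fn(submodule):
        # Checkpoint every n-th transformer block, counted in traversal order.
        if isinstance(submodule, TransformerBlock):
            counter["blocks_seen"] += 1
            return counter["blocks_seen"] % n == 0
        return False
    return check_fn

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
    check_fn=make_every_nth_check_fn(3),   # checkpoint one block in three (illustrative)
)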

Throughput and MFU, HFU computation

While we only trained the 7B model to 2T tokens, we performed numerous experiments on the other model sizes to provide the best configuration options. This is summarized in the table below for two types of infrastructure — an A100 cluster with 128 GPUs and 400Gbps inter-node interconnect, and an H100 cluster with 96 GPUs and 800Gbps inter-node interconnect.

Model size | Batch size | Activation checkpoint | Throughput tokens/sec/GPU (A100 80GB, 400Gbps interconnect) | MFU % (A100 80GB) | HFU % (A100 80GB) | Throughput tokens/sec/GPU (H100 80GB, 800Gbps interconnect) | MFU % (H100 80GB) | HFU % (H100 80GB)
7B | 2 | No | 3700 | 0.57 | 0.57 | 7500 | 0.37 | 0.37
13B | 2 | Selective | 1800 | 0.51 | 0.59 | 3800 | 0.35 | 0.40
34B | 2 | Yes | 700 | 0.47 | 0.64 | 1550 | 0.32 | 0.44
70B | 2 | Yes | 370 | 0.50 | 0.67 | 800 | 0.34 | 0.45

Table 1: Model and Hardware FLOPS utilization of various model sizes on A100 and H100 GPUs

HFU numbers are computed using the PyTorch FLOP counter and the theoretical bf16 performance of A100 and H100 GPUs, whereas MFU numbers are computed using the methodology outlined in NanoGPT and the PaLM paper. We also note that the batch sizes for the larger models are intentionally kept at 2 per GPU to mimic choices made when training models with a 4k sequence length, and to allow scaling up to 512 GPUs without exceeding the popular 4M-token global batch size. Beyond that, we would need tensor parallelism or sequence parallelism.
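
For intuition, a back-of-the-envelope MFU estimate for the 7B/A100 row can be written as follows; it uses the simplified 6·N FLOPs-per-token approximation and ignores the attention term, so it slightly underestimates the value reported in Table 1.

# Simplified MFU estimate (6 * params FLOPs per token, attention term ignored).
params = 7e9                    # 7B model
tokens_per_sec_per_gpu = 3700   # from Table 1
peak_bf16_tflops = 312          # A100 80GB theoretical bf16 peak

achieved_tflops = 6 * params * tokens_per_sec_per_gpu / 1e12
mfu = achieved_tflops / peak_bf16_tflops
print(f"~{achieved_tflops:.0f} TFLOPS/GPU, MFU ~{mfu:.2f}")   # ~155 TFLOPS, MFU ~0.50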

We note in the table above that, for A100s, activation recomputation causes the MFU to reduce, while HFU increases! With the introduction of better activation checkpointing schemes, we expect MFU to increase and catch up with HFU. However, we observe that for H100s, both MFU and HFU are relatively low. We analyzed the PyTorch profile traces on H100 and observed a 10% gap where network communication “peeks” out from behind computation. In addition, we hypothesize that the HBM bandwidth of H100s is the cause of the reduced HFU/MFU on H100s and of not being able to obtain the 3x improvement (H100s are theoretically 3x faster than A100s – 989 vs. 312 TFLOPS – but have less than 2x the HBM bandwidth of A100s – 3.35 vs. 2.0 TB/s). We plan to try out other configuration options like Tensor Parallel to improve the knobs for the 70B model on H100s.

Model details

The loss curve for training is shown in the below figure.


Figure 1: LlamaT training loss curve

The 2T checkpoint is converted to Hugging Face format by a script provided in the repository, and we then use lm-evaluation-harness to compute key academic benchmarks and compare against Llama2-7B by running the same evaluation on it. These results are captured in the table below.

Evaluation metric | Llama2-7B (baseline) | LlamaT-7B
MMLU (zero shot) | 0.41 | 0.43
MMLU (5-shot weighted avg) | 0.47 | 0.50
Arc challenge | 0.46 | 0.44
Arc easy | 0.74 | 0.71
Boolq | 0.78 | 0.76
Copa | 0.87 | 0.83
Hellaswag | 0.76 | 0.74
Openbookqa | 0.44 | 0.42
Piqa | 0.79 | 0.79
Sciq | 0.91 | 0.91
Winogrande | 0.69 | 0.67
Truthfulqa | 0.39 | 0.39
GSM8k (8-shot) | 0.13 | 0.11

Table 2: LM eval harness scores

We observe that the model performs competitively with Llama2 across these benchmarks.

Training chronicles

Training was stable with no crashes, though we did observe a few hiccups:

0-200B tokens: We observed a slowdown in the iteration time (time taken to execute one training step). We stopped the job to ensure that the data loader was not causing any slowdowns and the checkpointing was performant and accurate. We did not find any issues. By this time, HSDP checkpointing code was available in PyTorch, and we took this opportunity to make the switch to PyTorch checkpointing code.

200B tokens-1.9T: We did not perform any manual intervention on the job in late December. When we came back in early January, disk space had been exhausted and checkpoints were failing to be written, although the training job continued. The last known checkpoint was at 1.5T tokens.

1.5T-1.7T: We evaluated the 1.5T checkpoint with lm-evaluation-harness and discovered that the model had been trained with an extra special token between documents, because the Hugging Face tokenizer introduced a separator token and our dataloader also appended its own document separator. We modified the dataloader to eliminate the extra special token, and continued training with the modified dataloader from the 1.7T-token mark onwards.

1.7T-2T: The loss initially spiked due to the change in the special tokens, but recovered within a few billion tokens. The training finished without any other manual intervention!

Key takeaways and even more speed

We demonstrated how one can use FSDP to train a model to 2T tokens with an excellent performance of 3,700 tokens/sec/GPU, while producing a good quality model. As part of this exercise, we open sourced all our code for training and the knobs to achieve this throughput. These knobs can be leveraged not only by large-scale runs, but also by smaller scale tuning runs. You can find the code here.

FSDP APIs implement the ZeRO algorithms in a PyTorch native manner and allow for tuning and training of large models. In the past, we have seen FSDP proof points (Stanford Alpaca, Hugging Face, Llama 2 recipes) on tuning a variety of LLMs (such as Meta Llama 2, from 7B to 70B) using simple training loops and achieving good throughputs and training times.

Finally, we note that there are several levers for speeding up training:

  1. Node optimizations that can speed up specific operations (e.g., attention computation using Flash Attention V2)
  2. Graph optimizations (e.g., fusing kernels, torch.compile)
  3. Overlap in compute-communications
  4. Activation recomputation

We have leveraged 1, 3, and a variation of 4 in this blog and are working closely with Team PyTorch at Meta to get torch.compile (2) as well as a more advanced version of 4 with per-operator selective activation recomputation. We plan to share a simple formatting code and example data to ingest into our data loader to enable others to use the code base for training of models.

Acknowledgements

There are several teams that have been involved in reaching this proof point and we would like to thank the teams across Meta and IBM. Specifically, we extend our gratitude to the PyTorch distributed team, Facebook Research and Applied AI teams that built the FSDP APIs and made enhancements based on our feedback. We also wish to thank the data team at IBM Research that curated the data corpus used in this exercise and the infrastructure team at IBM Research (especially, Claudia Misale, Shweta Salaria, and Seetharami Seelam) that optimized NCCL and network configurations. By building and leveraging all of these components, we have successfully demonstrated the LlamaT proof point.

The selective activation checkpointing was conceptualized at IBM by Linsong Chu, Davis Wertheimer, Mudhakar Srivatsa, and Raghu Ganti and implemented by Less Wright at Meta.

Special thanks to Stas Bekman and Minjia Zhang, who provided extensive feedback and helped improve the blog. Their insights have been invaluable in highlighting key aspects of optimizing the training and exploring further enhancements.

Appendix

Communication computation overlap

Another key aspect of training in a multi-node setting is the ability to overlap communication and computation. In FSDP, there are multiple opportunities for overlapping – during the FSDP unit gathering phase in the forward pass as well as during the backward pass computation. Overlapping the gather during the forward pass with the computation of the previous unit, and overlapping backward computation with the gathering of the next unit and gradient scattering, helps improve GPU utilization by nearly 2x. We illustrate this on the 400Gbps network interconnect with A100 80GB GPUs. In the case of HSDP, there is no inter-node traffic during the pre-fetch stage for the forward pass, and the overlap is only for the backward gradient computation phase. Of course, HSDP is feasible only when the model can be sharded within a single node, limiting the size of models to around 30B parameters.
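
How much overlap FSDP can achieve is largely determined by how the model is wrapped and which prefetch options are enabled. Below is a minimal, illustrative configuration (not IBM's training stack); the TransformerBlock class is a stand-in for the real decoder layer, and the example assumes a process group has already been initialized (for example, via torchrun).

import functools
import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Stand-in transformer block; in practice this is the model's decoder layer class.
class TransformerBlock(nn.Module):
    def __init__(self, dim=256):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
    def forward(self, x):
        return x + self.ff(x)

model = nn.Sequential(*[TransformerBlock() for _ in range(4)])

# Wrap each transformer block as its own FSDP unit, so its all-gathers and
# reduce-scatters can overlap with the compute of neighboring units.
model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock}
    ),
    sharding_strategy=ShardingStrategy.FULL_SHARD,    # HYBRID_SHARD for HSDP
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch the next unit's parameters
    forward_prefetch=True,                            # issue forward all-gathers early
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
)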

The below figure shows three steps in FSDP with the communication between nodes at the bottom and the compute stream at the top of the second half of the image. For the 7B model with no activation recomputation, we observe the overlap to be complete. In practice, the overlap percentage possible is 90% since the first block during forward pass and the last block during backward pass are not able to overlap.

three steps in FSDP with the communication between nodes at the bottom and the compute stream at the top of the second half

A zoomed in view of the above three-step process is shown below for a single step. We can clearly see the granularity of the computation and communication and how they overlap in an interleaved manner.

zoomed in view of the above three-step process

Read More

Talk like a graph: Encoding graphs for large language models

Talk like a graph: Encoding graphs for large language models

Imagine all the things around you — your friends, tools in your kitchen, or even the parts of your bike. They are all connected in different ways. In computer science, the term graph is used to describe connections between objects. Graphs consist of nodes (the objects themselves) and edges (connections between two nodes, indicating a relationship between them). Graphs are everywhere now. The internet itself is a giant graph of websites linked together. Even the knowledge search engines use is organized in a graph-like way.

Furthermore, consider the remarkable advancements in artificial intelligence — such as chatbots that can write stories in seconds, and even software that can interpret medical reports. This exciting progress is largely thanks to large language models (LLMs). New LLM technology is constantly being developed for different uses.

Since graphs are everywhere and LLM technology is on the rise, in “Talk like a Graph: Encoding Graphs for Large Language Models”, presented at ICLR 2024, we present a way to teach powerful LLMs how to better reason with graph information. Graphs are a useful way to organize information, but LLMs are mostly trained on regular text. The objective is to test different techniques to see what works best and gain practical insights. Translating graphs into text that LLMs can understand is a remarkably complex task. The difficulty stems from the inherent complexity of graph structures with multiple nodes and the intricate web of edges that connect them. Our work studies how to take a graph and translate it into a format that an LLM can understand. We also design a benchmark called GraphQA to study different approaches on different graph reasoning problems and show how to phrase a graph-related problem in a way that enables the LLM to solve the graph problem. We show that LLM performance on graph reasoning tasks varies on three fundamental levels: 1) the graph encoding method, 2) the nature of the graph task itself, and 3) interestingly, the very structure of the graph considered. These findings give us clues on how to best represent graphs for LLMs. Picking the right method can make the LLM up to 60% better at graph tasks!

Pictured, the process of encoding a graph as text using two different approaches and feeding the text and a question about the graph to the LLM.

Graphs as text

To be able to systematically find out what is the best way to translate a graph to text, we first design a benchmark called GraphQA. Think of GraphQA as an exam designed to evaluate powerful LLMs on graph-specific problems. We want to see how well LLMs can understand and solve problems that involve graphs in different setups. To create a comprehensive and realistic exam for LLMs, we don’t just use one type of graph, we use a mix of graphs ensuring breadth in the number of connections. This is mainly because different graph types make solving such problems easier or harder. This way, GraphQA can help expose biases in how an LLM thinks about the graphs, and the whole exam gets closer to a realistic setup that LLMs might encounter in the real world.

Overview of our framework for reasoning with graphs using LLMs.

GraphQA focuses on simple tasks related to graphs, like checking if an edge exists, calculating the number of nodes or edges, finding nodes that are connected to a specific node, and checking for cycles in a graph. These tasks might seem basic, but they require understanding the relationships between nodes and edges. By covering different types of challenges, from identifying patterns to creating new connections, GraphQA helps models learn how to analyze graphs effectively. These basic tasks are crucial for more complex reasoning on graphs, like finding the shortest path between nodes, detecting communities, or identifying influential nodes. Additionally, GraphQA includes generating random graphs using various algorithms like Erdős–Rényi, scale-free networks, the Barabási–Albert model, and the stochastic block model, as well as simpler graph structures like paths, complete graphs, and star graphs, providing a diverse set of data for training.
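
For illustration only (this is not the GraphQA code), the kinds of graphs and basic questions described above can be reproduced with a few lines of networkx; the sizes and probabilities below are arbitrary.

import networkx as nx

# A few of the generators mentioned above, with arbitrary sizes and parameters.
graphs = {
    "Erdős–Rényi": nx.erdos_renyi_graph(n=12, p=0.2, seed=0),
    "Barabási–Albert": nx.barabasi_albert_graph(n=12, m=2, seed=0),
    "star": nx.star_graph(n=11),
    "path": nx.path_graph(n=12),
}

# Simple GraphQA-style questions: node count, edge count, and cycle existence.
for name, g in graphs.items():
    has_cycle = len(nx.cycle_basis(g)) > 0
    print(f"{name}: {g.number_of_nodes()} nodes, {g.number_of_edges()} edges, cycle: {has_cycle}")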

When working with graphs, we also need to find ways to ask graph-related questions that LLMs can understand. Prompting heuristics are different strategies for doing this. Let’s break down the common ones:

  • Zero-shot: simply describe the task (“Is there a cycle in this graph?”) and tell the LLM to go for it. No examples provided.
  • Few-shot: This is like giving the LLM a mini practice test before the real deal. We provide a few example graph questions and their correct answers.
  • Chain-of-Thought: Here, we show the LLM how to break down a problem step-by-step with examples. The goal is to teach it to generate its own “thought process” when faced with new graphs.
  • Zero-CoT: Similar to CoT, but instead of training examples, we give the LLM a simple prompt, like “Let’s think step-by-step,” to trigger its own problem-solving breakdown.
  • BAG (build a graph): This is specifically for graph tasks. We add the phrase “Let’s build a graph…” to the description, helping the LLM focus on the graph structure.

We explored different ways to translate graphs into text that LLMs can work with. Our key questions were:

  • Node encoding: How do we represent individual nodes? Options tested include simple integers, common names (people, characters), and letters.
  • Edge encoding: How do we describe the relationships between nodes? Methods involved parenthesis notation, phrases like “are friends”, and symbolic representations like arrows.

Various node and edge encodings were combined systematically. This led to functions like the ones in the following figure:

Examples of graph encoding functions used to encode graphs via text.
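
As a concrete, hypothetical example of such an encoding function, the snippet below renders a small graph as text using people's names for nodes and a "friendship" phrasing for edges, roughly in the spirit of the encodings described above.

def encode_graph_as_text(edges, node_names):
    """Render an undirected graph as natural-language text for an LLM prompt."""
    lines = [f"G describes a friendship graph among {', '.join(node_names)}."]
    lines.append("We have the following edges in G:")
    for u, v in edges:
        lines.append(f"{node_names[u]} and {node_names[v]} are friends.")
    return "\n".join(lines)

names = ["Alice", "Bob", "Carol", "David"]
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
prompt = encode_graph_as_text(edges, names) + "\nQ: Is there a cycle in this graph?"
print(prompt)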

Analysis and results

We carried out three key experiments: one to test how LLMs handle graph tasks, and two to understand how the size of the LLM and different graph shapes affected performance. We run all our experiments on GraphQA.

How LLMs handle graph tasks

In this experiment, we tested how well pre-trained LLMs tackle graph problems like identifying connections, cycles, and node degrees. Here is what we learned:

  • LLMs struggle: On most of these basic tasks, LLMs did not do much better than a random guess.
  • Encoding matters significantly: How we represent the graph as text has a great effect on LLM performance. The “incident” encoding excelled for most of the tasks in general.

Our results are summarized in the following chart.

Comparison of various graph encoder functions based on their accuracy on different graph tasks. The main conclusion from this figure is that the graph encoding functions matter significantly.

Bigger is (usually) better

In this experiment, we wanted to see if the size of the LLM (in terms of the number of parameters) affects how well they can handle graph problems. For that, we tested the same graph tasks on the XXS, XS, S, and L sizes of PaLM 2. Here is a summary of our findings:

  • In general, bigger models did better on graph reasoning tasks. It seems like the extra parameters gave them space to learn more complex patterns.
  • Oddly, size didn’t matter as much for the “edge existence” task (finding out if two nodes in a graph are connected).
  • Even the biggest LLM couldn’t consistently beat a simple baseline solution on the cycle check problem (finding out if a graph contains a cycle or not). This shows LLMs still have room to improve with certain graph tasks.

Effect of model capacity on graph reasoning task for PaLM 2-XXS, XS, S, and L.

Do different graph shapes confuse LLMs?

We wondered if the “shape” of a graph (how nodes are connected) influences how well LLMs can solve problems on it. Think of the following figure as different examples of graph shapes.

Samples of graphs generated with different graph generators from GraphQA. ER, BA, SBM, and SFN refers to Erdős–Rényi, Barabási–Albert, Stochastic Block Model, and Scale-Free Network respectively.

We found that graph structure has a big impact on LLM performance. For example, in a task asking if a cycle exists, LLMs did great on tightly interconnected graphs (cycles are common there) but struggled on path graphs (where cycles never happen). Interestingly, providing some mixed examples helped it adapt. For instance, for cycle check, we added some examples containing a cycle and some examples with no cycles as few-shot examples in our prompt. Similar patterns occurred with other tasks.

Comparing different graph generators on different graph tasks. The main observation here is that graph structure has a significant impact on the LLM’s performance. ER, BA, SBM, and SFN refers to Erdős–Rényi, Barabási–Albert, Stochastic Block Model, and Scale-Free Network respectively.

Conclusion

In short, we dug deep into how to best represent graphs as text so LLMs can understand them. We found three major factors that make a difference:

  • How to translate the graph to text: How we represent the graph as text significantly influences LLM performance. The incident encoding excelled for most of the tasks in general.
  • Task type: Certain types of graph questions tend to be harder for LLMs, even with a good translation from graph to text.
  • Graph structure: Surprisingly, the “shape” of the graph on which we do inference (dense with connections, sparse, etc.) influences how well an LLM does.

This study revealed key insights about how to prepare graphs for LLMs. The right encoding techniques can significantly boost an LLM’s accuracy on graph problems (ranging from around 5% to over 60% improvement). Our new benchmark, GraphQA, will help drive further research in this area.

Acknowledgements

We would like to express our gratitude to our co-author, Jonathan Halcrow, for his valuable contributions to this work. We express our sincere gratitude to Anton Tsitsulin, Dustin Zelle, Silvio Lattanzi, Vahab Mirrokni, and the entire graph mining team at Google Research, for their insightful comments, thorough proofreading, and constructive feedback which greatly enhanced the quality of our work. We would also like to extend special thanks to Tom Small for creating the animation used in this post.

Read More

Head of the Class: Explore AI’s Potential in Higher Education and Research at GTC

Head of the Class: Explore AI’s Potential in Higher Education and Research at GTC

For students, researchers and educators eager to delve into AI, GTC — NVIDIA’s conference on AI and accelerated computing — is in a class of its own.

Taking place from March 18-21 at the San Jose Convention Center, GTC features over 900 talks presented by world-renowned experts in fields such as generative AI, high performance computing, healthcare, energy and environment, and robotics.

See some of the top sessions for attendees in higher education below. And don’t miss NVIDIA founder and CEO Jensen Huang’s GTC keynote on how AI is transforming industries, on Monday, March 18, at 1 p.m. PT.

For Researchers 

See more sessions for researchers.

For Educators

Find more sessions for educators.

For Students

Discover more sessions for students and apply to join the NVIDIA Student Network.

To gain hands-on experience, check out training labs and full-day technical workshops at GTC.

Read More