-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Introducing BF16 Pointwise NCHWc Convolution for Arm64 #26838
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@hariharans29 another PR that's up your alley. Can you request a preliminary review from Copilot & run CI? Thanks! |
|
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline |
|
Azure Pipelines successfully started running 4 pipeline(s). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR extends BF16 (bfloat16) precision optimization support to pointwise (1x1) NCHWc convolutions on ARM64 Linux platforms. The implementation leverages the existing SBGEMM infrastructure and the mlas.enable_gemm_fastmath_arm64_bfloat16 session option, delivering a reported 15-20% performance improvement on Mobilenet inference.
Key changes:
- Adds BF16 pointwise convolution kernel (
MlasConvPointwiseBf16KernelNeon) that delegates computation to SBGEMM - Introduces
ZeroModefield toMLAS_SBGEMM_DATA_PARAMSto enable accumulation control across multiple GEMM calls - Routes pointwise convolutions to BF16 implementation when fastmath mode is enabled on supported hardware
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp |
New BF16 pointwise convolution kernel implementation using SBGEMM batch operations |
onnxruntime/core/mlas/inc/mlas.h |
Adds UseBf16 parameter to MlasNchwcConv API and ZeroMode field to MLAS_SBGEMM_DATA_PARAMS |
onnxruntime/core/mlas/lib/sbgemm.h |
Propagates ZeroMode parameter through SBGEMM packed/non-packed operations |
onnxruntime/core/mlas/lib/snchwc.cpp |
Adds UseBf16 parameter and conditional BF16 kernel selection logic |
onnxruntime/core/mlas/lib/mlasi.h |
Declares MlasConvPointwiseBf16KernelNeon and adds ConvPointwiseBf16Kernel to platform struct |
onnxruntime/core/mlas/lib/platform.cpp |
Initializes BF16 kernel pointer in ARM64 NEON platform initialization |
onnxruntime/contrib_ops/cpu/nchwc_ops.h |
Adds fastmath mode detection in constructor and member variable |
onnxruntime/contrib_ops/cpu/nchwc_ops.cc |
Passes BF16 flag to MlasNchwcConv based on session options |
cmake/onnxruntime_mlas.cmake |
Adds new source file with ARM BF16 compilation flags |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
Thanks @hariharans29 ! |
Description
This PR adds a BF16 (bfloat16) pointwise convolution kernel for ARM64 NCHWc format, leveraging the existing SBGEMM infrastructure. When the
mlas.enable_gemm_fastmath_arm64_bfloat16session option is enabled on supported ARM64 Linux hardware, Pointwise Conv is rerouted to use this BF16 implementation. This is an opt-in feature, similar to how BF16 matmul is opt-in.Added a bool ZeroMode field to
MLAS_SBGEMM_DATA_PARAMS(defaulttruefor backward compatibility) to enable per-batch control over output accumulation. This mirrors the beta parameter in FP32'sMlasGemmBatchand is required for Pointwise convolutions with >128 input channels, where multiple GEMM calls must accumulate into the same output buffer.Motivation and Context
The existing
mlas.enable_gemm_fastmath_arm64_bfloat16session option accelerates MatMul operations on ARM64 processors with BF16 support, but convolution operations did not benefit from this optimization. Pointwise convolutions (1x1 kernels) are essentially batched matrix multiplications.This change extends the BF16 fastmath optimization to pointwise NCHWc convolutions, reusing the same session option. The implementation mirrors the FP32 pointwise kernel structure while delegating the actual computation to SBGEMM, ensuring correctness and maintainability.
Performance improvement
Measured a 15-20% gain on Mobilenet inference on an AWS Graviton4 instance.
Before (FP32)
After (BF16)