
Conversation

@Rohanjames1997 (Contributor) commented Dec 19, 2025

Description

This PR adds a BF16 (bfloat16) pointwise convolution kernel for the ARM64 NCHWc format, leveraging the existing SBGEMM infrastructure. When the mlas.enable_gemm_fastmath_arm64_bfloat16 session option is enabled on supported ARM64 Linux hardware, pointwise convolutions are routed to this BF16 implementation. As with BF16 MatMul, the feature is opt-in.
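For readers unfamiliar with the option, below is a minimal sketch of opting in from application code via the C++ API. The config key is the existing session option named above; the model path is a placeholder, and the flag only takes effect on hardware with BF16 support.

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "bf16_pointwise_demo"};
  Ort::SessionOptions session_options;

  // Opt in to the ARM64 BF16 fastmath path; it only applies on
  // hardware that supports BF16.
  session_options.AddConfigEntry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1");

  // "mobilenet.onnx" is a placeholder model path.
  Ort::Session session{env, "mobilenet.onnx", session_options};
  // ... create inputs and call session.Run() as usual ...
  return 0;
}
```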

Added a bool ZeroMode field to MLAS_SBGEMM_DATA_PARAMS (defaulting to true for backward compatibility) to enable per-batch control over output accumulation. This mirrors the beta parameter in FP32's MlasGemmBatch and is required for pointwise convolutions with more than 128 input channels, where multiple GEMM calls must accumulate into the same output buffer.
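To make the accumulation pattern concrete, here is a rough sketch of how ZeroMode could be used when the reduction dimension is split into 128-channel slices. MlasSBGemmBatch and the other MLAS_SBGEMM_DATA_PARAMS fields (A, lda, B, ldb, C, ldc) are shown as I understand the existing SBGEMM API; treat the exact names, layouts, and signatures as assumptions rather than the PR's actual kernel code.

```cpp
#include <algorithm>
#include <cstddef>
#include "mlas.h"  // internal MLAS header; assumed to declare MlasSBGemmBatch

// Hypothetical illustration: split the K (input-channel) dimension into
// 128-wide slices and let each SBGEMM call either zero or accumulate into
// the shared output C via the new ZeroMode field.
void PointwiseConvSlicedSbgemm(const float* A, const float* B, float* C,
                               size_t M, size_t N, size_t K,
                               MLAS_THREADPOOL* ThreadPool) {
  constexpr size_t kSliceK = 128;
  for (size_t k0 = 0; k0 < K; k0 += kSliceK) {
    const size_t kc = std::min(kSliceK, K - k0);

    MLAS_SBGEMM_DATA_PARAMS params{};
    params.A = A + k0;        // A is M x K, row-major: step into the K dimension
    params.lda = K;
    params.B = B + k0 * N;    // B is K x N, row-major: step over whole rows
    params.ldb = N;
    params.C = C;             // every slice targets the same output buffer
    params.ldc = N;
    params.ZeroMode = (k0 == 0);  // overwrite C on the first slice, then accumulate

    MlasSBGemmBatch(M, N, kc, /*BatchN=*/1, &params, ThreadPool);
  }
}
```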

Motivation and Context

The existing mlas.enable_gemm_fastmath_arm64_bfloat16 session option accelerates MatMul operations on ARM64 processors with BF16 support, but convolution operations did not benefit from this optimization. Pointwise convolutions (1x1 kernels) are essentially batched matrix multiplications.

This change extends the BF16 fastmath optimization to pointwise NCHWc convolutions, reusing the same session option. The implementation mirrors the FP32 pointwise kernel structure while delegating the actual computation to SBGEMM, ensuring correctness and maintainability.
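For reference, a plain-C++ sketch of the reduction a 1x1 convolution performs on a single image (simple NCHW layout, no NCHWc blocking; names are illustrative): it is exactly a GEMM with M = output channels, K = input channels, and N = spatial size.

```cpp
#include <cstddef>

// Naive 1x1 convolution over an NCHW tensor for one image: equivalent to
// a GEMM C[M x N] = W[M x K] * X[K x N] with M = Cout, K = Cin, N = H * W.
void PointwiseConvReference(const float* X,   // input,  [Cin  x H*W]
                            const float* W,   // filter, [Cout x Cin] (1x1 kernel)
                            float* Y,         // output, [Cout x H*W]
                            size_t Cin, size_t Cout, size_t HW) {
  for (size_t m = 0; m < Cout; ++m) {
    for (size_t n = 0; n < HW; ++n) {
      float acc = 0.0f;
      for (size_t k = 0; k < Cin; ++k) {
        acc += W[m * Cin + k] * X[k * HW + n];
      }
      Y[m * HW + n] = acc;
    }
  }
}
```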

Performance improvement

Measured a 15-20% throughput gain for MobileNet inference on an AWS Graviton4 instance.

Before (FP32)

./build/Linux/Release/onnxruntime_perf_test -C "mlas.enable_gemm_fastmath_arm64_bfloat16|0" -x 32 -I -m times -r 2000 ~/scripts/mobilenet.onnx

Number of inferences per second: 559.154

After (BF16)

./build/Linux/Release/onnxruntime_perf_test -C "mlas.enable_gemm_fastmath_arm64_bfloat16|1" -x 32 -I -m times -r 2000 ~/scripts/mobilenet.onnx

Number of inferences per second: 651.221

@Rohanjames1997 (Contributor, Author)

@hariharans29 another PR that's up your alley.

Can you request a preliminary review from Copilot & run CI?

Thanks!

hariharans29 requested a review from Copilot on December 21, 2025 05:23
@hariharans29 (Member)

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

Copilot AI (Contributor) left a comment

Pull request overview

This PR extends BF16 (bfloat16) precision optimization support to pointwise (1x1) NCHWc convolutions on ARM64 Linux platforms. The implementation leverages the existing SBGEMM infrastructure and the mlas.enable_gemm_fastmath_arm64_bfloat16 session option, delivering a reported 15-20% performance improvement on Mobilenet inference.

Key changes:

  • Adds BF16 pointwise convolution kernel (MlasConvPointwiseBf16KernelNeon) that delegates computation to SBGEMM
  • Introduces ZeroMode field to MLAS_SBGEMM_DATA_PARAMS to enable accumulation control across multiple GEMM calls
  • Routes pointwise convolutions to BF16 implementation when fastmath mode is enabled on supported hardware

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated no comments.

Summary per file:

  • onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp: New BF16 pointwise convolution kernel implementation using SBGEMM batch operations
  • onnxruntime/core/mlas/inc/mlas.h: Adds UseBf16 parameter to MlasNchwcConv API and ZeroMode field to MLAS_SBGEMM_DATA_PARAMS
  • onnxruntime/core/mlas/lib/sbgemm.h: Propagates ZeroMode parameter through SBGEMM packed/non-packed operations
  • onnxruntime/core/mlas/lib/snchwc.cpp: Adds UseBf16 parameter and conditional BF16 kernel selection logic
  • onnxruntime/core/mlas/lib/mlasi.h: Declares MlasConvPointwiseBf16KernelNeon and adds ConvPointwiseBf16Kernel to platform struct
  • onnxruntime/core/mlas/lib/platform.cpp: Initializes BF16 kernel pointer in ARM64 NEON platform initialization
  • onnxruntime/contrib_ops/cpu/nchwc_ops.h: Adds fastmath mode detection in constructor and member variable
  • onnxruntime/contrib_ops/cpu/nchwc_ops.cc: Passes BF16 flag to MlasNchwcConv based on session options
  • cmake/onnxruntime_mlas.cmake: Adds new source file with ARM BF16 compilation flags


@Rohanjames1997 (Contributor, Author) commented Dec 22, 2025

Thanks @hariharans29!
The failures look like they're caused by inconsistent ifdefs(?). I'm looking into it.
Do let me know if you have ideas too, but I may need you to rerun CI a few more times after I push fixes.
