Preconditioned Newton–Schulz (Turbo-Muon Variant)
This repository provides a Triton-accelerated implementation of the Preconditioned Newton–Schulz (NS) iteration used for fast approximate polar decomposition / orthogonalization. The method follows the Turbo-Muon formulation (Boissin, Massena et al., 2025), including AOL preconditioning fused directly into the first iteration.
The implementation is designed for large matrices on modern accelerators, with minimal memory traffic and fully fused Triton kernels.
Usage
Basic example
import torch
from kernels import get_kernel
kern = get_kernel("tboissin/newton_schulz_triton")
newton_schulz = torch.compile(kern.newton_schulz)
n = 4096
x = torch.randn((16, n, n), device="cuda", dtype=torch.bfloat16)
# Preconditioned (recommended)
out = newton_schulz(x, iter=4, precondition=True, epsilon=1e-7)
# Without preconditioning (requires 5 steps to match accuracy)
out2 = newton_schulz(x, iter=5, precondition=False, epsilon=1e-7)
# Orthogonality error
err = torch.norm(out @ out.mT - torch.eye(n, device="cuda", dtype=torch.bfloat16), dim=(1,2)).mean()
err2 = torch.norm(out2 @ out2.mT - torch.eye(n, device="cuda", dtype=torch.bfloat16), dim=(1,2)).mean()
print("Orthogonality error (preconditioned):", err.item())
print("Orthogonality error (not preconditioned):", err2.item())
Function Overview
The entry point is newton_schulz(G, iter=4, precondition=True, epsilon=1e-7, dtype=torch.bfloat16).
This routine implements the Turbo-Muon variant of the Newton–Schulz (NS) method with AOL preconditioning, as described in “Turbo-Muon: Accelerating Orthogonality-Based Optimization with Pre-Conditioning” (Boissin, Massena et al., 2025). It serves as the fast orthogonalization step used in Muon-style optimizers.
Given an input matrix (or batch) G ∈ ℝ^{…, m, n}, the function computes an approximate polar factor X whose rows or columns are close to orthonormal and span the same subspace as G. For non-square matrices, the algorithm returns the closest pseudo-orthogonal factor in the Stiefel sense:
- XᵀX = I when m ≥ n
- XXᵀ = I when m ≤ n
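For example, with a tall input the columns come out approximately orthonormal. The sketch below reuses the newton_schulz handle from the usage example above:
# Tall input (m > n): expect XᵀX ≈ I, per the first case above.
g = torch.randn((1, 8192, 1024), device="cuda", dtype=torch.bfloat16)
x = newton_schulz(g, iter=4, precondition=True, epsilon=1e-7)
eye = torch.eye(1024, device="cuda", dtype=torch.bfloat16)
print(torch.norm(x.mT @ x - eye).item())  # small residual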
Triton Kernels
The implementation uses three Triton kernels:
- ns_line_1(X, out=A): computes the Gram matrix A = X @ Xᵀ.
- ns_line_2(A, alpha=c, beta=b, out=B): applies a quadratic polynomial in A (fourth order in X): B = bA + cA².
- ns_line_3(B, X, a, out=C): performs the NS update step: C = aX + B @ X.
Together, these realize one Newton–Schulz iteration:
$$X_{k+1} = a_k X_k + b_k (X_k X_k^{\top}) X_k + c_k (X_k X_k^{\top})^2 X_k$$
The coefficients ((a_k, b_k, c_k)) are drawn from a small fixed table optimized for up to five iterations.
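For reference, one iteration can be written in plain PyTorch as below (a sketch: ns_step is an illustrative name, and the Triton kernels fuse these three lines with far less memory traffic):
import torch

def ns_step(X, a, b, c):
    # One Newton–Schulz iteration, composed from the three kernel steps.
    A = X @ X.mT             # ns_line_1: Gram matrix
    B = b * A + c * (A @ A)  # ns_line_2: polynomial in A
    return a * X + B @ X     # ns_line_3: update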
Parameters
G : torch.Tensor
Input tensor with shape (..., m, n). The last two dimensions are treated as a matrix; leading dimensions are batch dimensions. Computation runs on the input's device.
iter : int, default = 4
Number of Newton–Schulz iterations (1–5). Turbo-Muon typically uses 4 iterations, matching the accuracy of the classical 5-step NS schedule at lower runtime.
precondition : bool, default = True
Enables AOL preconditioning, fused with the first NS iteration:
- Compute A₀ = X₀X₀ᵀ via ns_line_1.
- Compute the per-row scaling sᵢ = 1 / √(Σⱼ |A₀[i,j]| + ε).
- Rescale X₀ ← X₀ · s[..., None] to improve conditioning.
- Update the Gram matrix in place: A₁ = diag(s) · A₀ · diag(s), i.e. A₁[i,j] = sᵢ A₀[i,j] sⱼ.
- Use A₁ directly in the first NS step.
AOL preconditioning enforces a safe spectral scale and improves conditioning, allowing one fewer NS iteration (a plain-PyTorch sketch follows below).
If disabled, the input is instead normalized by its approximate Frobenius norm: X₀ = G / (‖G‖_F + ε).
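A plain-PyTorch sketch of the AOL steps above (illustrative only: aol_precondition is a hypothetical name, and the library fuses this logic with the first NS iteration):
import torch

def aol_precondition(X, epsilon=1e-7):
    A = X @ X.mT                                 # A₀ = X₀X₀ᵀ (ns_line_1)
    s = (A.abs().sum(dim=-1) + epsilon).rsqrt()  # per-row scale sᵢ
    X = X * s[..., None]                         # X₀ ← X₀ · s[..., None]
    A = A * s[..., None] * s[..., None, :]       # A₁ = diag(s) · A₀ · diag(s)
    return X, A                                  # A₁ feeds the first NS step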
epsilon : float, default = 1e-7
Lower bound for denominators to ensure numerical stability.
dtype : torch.dtype, default = torch.bfloat16
Working dtype for all computations. Supports BF16 for performance or FP32 for higher precision.
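For instance, to trade speed for precision (reusing the handle and input from the usage example above):
# Run the iteration in fp32; slower than bf16 but more accurate.
out_fp32 = newton_schulz(x.float(), iter=4, precondition=True, epsilon=1e-7, dtype=torch.float32)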
Returns
X : torch.Tensor
Approximate polar factor of G, with the same shape and dtype. If the input is tall (m > n), the routine transposes internally to operate on a wide matrix, then transposes back.
The output satisfies approximately:
- XᵀX = I for m ≥ n
- XXᵀ = I for m ≤ n
and spans the same column/row space as G.
Implementation Details
- Uses torch.compile() for graph-level optimization.
- The coefficients (a, b, c) are stored in ns_consts; ns_consts[-iter:] selects the appropriate schedule.
- When preconditioning is enabled, the first triple is consumed by the fused AOL–NS step.
Three work buffers are reused to avoid allocation overhead:
- A for Gram matrices
- B for polynomial outputs
- C for updated iterates
The design mirrors the Turbo-Muon formulation, matching the polar accuracy of optimized 5-step Muon at lower runtime.
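Putting these details together, the non-preconditioned path reduces to a loop like the following (a sketch: newton_schulz_reference is a hypothetical name, the coefficient triples come from the ns_consts table, and the real Triton path fuses the kernels and the AOL step):
import torch

def newton_schulz_reference(G, ns_consts, iters=5, epsilon=1e-7):
    # Frobenius normalization, as in the non-preconditioned path
    # (this also copies G, so the caller's tensor is never mutated).
    X = G / (G.norm(dim=(-2, -1), keepdim=True) + epsilon)
    m = X.shape[-2]
    A = X.new_empty(*X.shape[:-2], m, m)  # work buffer: Gram matrices
    B = torch.empty_like(A)               # work buffer: polynomial outputs
    C = torch.empty_like(X)               # work buffer: updated iterates
    for a, b, c in ns_consts[-iters:]:    # select the schedule
        torch.matmul(X, X.mT, out=A)          # ns_line_1
        torch.add(b * A, c * (A @ A), out=B)  # ns_line_2
        torch.add(a * X, B @ X, out=C)        # ns_line_3
        X, C = C, X                           # reuse buffers, no reallocation
    return X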
Performance Notes
- Designed for GPUs; fp16 and bf16 are supported.
- Preconditioning is intended for approximate estimation on large matrices.
- torch.compile is supported and advised (the kernel is not compiled by default).
- An autograd-compatible version is available as newton_schulz_autograd. This feature is experimental (see the sketch below).
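A minimal sketch (assuming newton_schulz_autograd shares the forward signature, which may change while the feature is experimental):
x = torch.randn((2, 1024, 1024), device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = kern.newton_schulz_autograd(x, iter=4, precondition=True, epsilon=1e-7)
y.sum().backward()    # gradients flow back through the NS iterations
print(x.grad.shape)   # same shape as x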
Citation
Boissin, Massena et al., 2025.
Turbo-Muon: Accelerating Orthogonality-Based Optimization with Pre-Conditioning.
License
MIT License.
Credits for ns_line_1 and ns_line_2: https://github.com/microsoft/dion, which is licensed under the MIT License.