
Preconditioned Newton–Schulz (Turbo-Muon Variant)

This repository provides a Triton-accelerated implementation of the Preconditioned Newton–Schulz (NS) iteration used for fast approximate polar decomposition / orthogonalization. The method follows the Turbo-Muon formulation (Boissin, Massena et al., 2025), including AOL preconditioning fused directly into the first iteration.

The implementation is designed for large matrices on modern accelerators, with minimal memory traffic and fully fused Triton kernels.


Usage

Basic example

import torch
from kernels import get_kernel

kern = get_kernel("tboissin/newton_schulz_triton")
newton_schulz = torch.compile(kern.newton_schulz)

n = 4096
x = torch.randn((16, n, n), device="cuda", dtype=torch.bfloat16)

# Preconditioned (recommended)
out = newton_schulz(x, iter=4, precondition=True, epsilon=1e-7)

# Without preconditioning (requires 5 steps to match accuracy)
out2 = newton_schulz(x, iter=5, precondition=False, epsilon=1e-7)

# Orthogonality error, measured in float32 so that BF16 rounding in the check itself does not dominate
eye = torch.eye(n, device="cuda", dtype=torch.float32)
err = torch.linalg.matrix_norm(out.float() @ out.float().mT - eye).mean()
err2 = torch.linalg.matrix_norm(out2.float() @ out2.float().mT - eye).mean()

print("Orthogonality error (preconditioned):", err.item())
print("Orthogonality error (not preconditioned):", err2.item())

Function Overview

Signature (with defaults): newton_schulz(G, iter=4, precondition=True, epsilon=1e-7, dtype=torch.bfloat16).

This routine implements the Turbo-Muon variant of the Newton–Schulz (NS) method with AOL preconditioning, as described in “Turbo-Muon: Accelerating Orthogonality-Based Optimization with Pre-Conditioning” (Boissin, Massena et al., 2025). It serves as the fast orthogonalization step used in Muon-style optimizers.

Given an input matrix (or batch) G ∈ ℝ^{…, m, n}, the function computes an approximate polar factor X whose rows or columns are close to orthonormal and span the same subspace as G. For non-square matrices, the result approximates the closest semi-orthogonal matrix (a point on the Stiefel manifold), so that:

  • XᵀX = I when m ≥ n
  • XXᵀ = I when m ≤ n
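When checking an approximate result, it can help to compare against the exact polar factor. For a thin SVD G = U Σ Vᵀ, the polar factor is U Vᵀ, which satisfies exactly the two conditions above. The helper below (polar_factor_svd, a hypothetical name) is a slow NumPy reference for testing, not part of this repository.

```python
import numpy as np

def polar_factor_svd(G: np.ndarray) -> np.ndarray:
    """Exact polar factor of G (batched over leading dims) via a thin SVD.

    Slow reference only: if G = U diag(S) Vh, the polar factor is U @ Vh.
    """
    U, _, Vh = np.linalg.svd(G, full_matrices=False)
    return U @ Vh

rng = np.random.default_rng(0)
tall = polar_factor_svd(rng.standard_normal((6, 3)))   # m >= n
wide = polar_factor_svd(rng.standard_normal((3, 6)))   # m <= n
print(np.allclose(tall.T @ tall, np.eye(3)))  # True: X^T X = I
print(np.allclose(wide @ wide.T, np.eye(3)))  # True: X X^T = I
```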

Triton Kernels

The implementation uses three Triton kernels:

  • ns_line_1(X, out=A): computes the Gram matrix A = X @ Xᵀ.

  • ns_line_2(A, alpha=c, beta=b, out=B): applies the matrix polynomial B = bA + cA² (quadratic in A, hence 4th-order in X).

  • ns_line_3(B, X, a, out=C): performs the NS update step C = aX + B @ X.

Together, these realize one Newton–Schulz iteration:

\[ X_{k+1} = a_k X_k + b_k X_k X_k^{\top} X_k + c_k X_k X_k^{\top} X_k X_k^{\top} X_k \]
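Substituting the kernel outputs into one another recovers this update. Writing X_k = U Σ_k Vᵀ (a thin SVD), every term shares the same singular vectors, so each iteration applies an odd quintic polynomial to the singular values and leaves the row and column spaces unchanged; the coefficients are chosen so that repeated steps push the singular values toward 1:

\[ A_k = X_k X_k^{\top}, \qquad B_k = b_k A_k + c_k A_k^{2} \]
\[ X_{k+1} = a_k X_k + B_k X_k = a_k X_k + b_k X_k X_k^{\top} X_k + c_k X_k X_k^{\top} X_k X_k^{\top} X_k \]
\[ X_k = U \Sigma_k V^{\top} \;\Longrightarrow\; X_{k+1} = U \left( a_k \Sigma_k + b_k \Sigma_k^{3} + c_k \Sigma_k^{5} \right) V^{\top} \]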

The coefficients (a_k, b_k, c_k) are drawn from a small fixed table optimized for up to five iterations.


Parameters

G : torch.Tensor

Input tensor with shape (..., m, n). The last two dimensions are treated as a matrix; leading dimensions represent batch dimensions. The device follows that of the input.

iter : int, default = 4

Number of Newton–Schulz iterations (1–5). Turbo-Muon typically uses 4 iterations, achieving accuracy comparable to classical 5-step NS but faster.

precondition : bool, default = True

Enables AOL preconditioning, fused with the first NS iteration:

  1. Compute A₀ = X₀ X₀ᵀ via ns_line_1.
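  2. Compute the per-row scaling s_i = 1 / √(∑_j |A₀[i,j]| + ε).
  3. Rescale the rows of X₀ to improve conditioning: X₀ ← X₀ · s[..., None], i.e., row i is multiplied by s_i.
  4. Update the Gram matrix in place instead of recomputing it: A₁ = diag(s) A₀ diag(s), i.e., A₁[i,j] = s_i A₀[i,j] s_j.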
  5. Use A₁ directly in the first NS step.

AOL preconditioning guarantees that the rescaled X₀ has spectral norm at most 1 (a safe starting scale for NS) and typically improves conditioning, which allows one fewer NS iteration.
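The preconditioning steps can be sketched in NumPy as follows. aol_precondition is a hypothetical helper written for illustration, computed in float64 rather than BF16; it shows why step 4 needs no second matrix product (the rescaled Gram matrix is just a diagonal scaling of A₀) and checks the spectral-norm bound on a random input.

```python
import numpy as np

def aol_precondition(X0: np.ndarray, epsilon: float = 1e-7):
    """Steps 1-4 of the AOL preconditioning above, for (..., m, n) arrays."""
    A0 = X0 @ np.swapaxes(X0, -1, -2)                     # 1. Gram matrix A0 = X0 X0^T
    s = 1.0 / np.sqrt(np.abs(A0).sum(axis=-1) + epsilon)  # 2. per-row scale s_i
    X = X0 * s[..., None]                                 # 3. row i scaled by s_i
    A1 = s[..., :, None] * A0 * s[..., None, :]           # 4. diag(s) A0 diag(s)
    return X, A1                                          # 5. A1 feeds the first NS step

rng = np.random.default_rng(0)
X, A1 = aol_precondition(rng.standard_normal((32, 64)))
print(np.allclose(A1, X @ X.T))            # True: A1 equals the Gram matrix of the rescaled X
print(np.linalg.norm(X, 2) <= 1.0 + 1e-9)  # True: spectral norm is at most 1
```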

If disabled, the input is instead normalized by its (approximate) Frobenius norm:

\[ X_0 = \frac{G}{\lVert G \rVert_F + \epsilon} \]
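Because the spectral norm never exceeds the Frobenius norm, this normalization also places every singular value of X₀ below 1, the range NS schedules typically assume; the smallest singular values can, however, end up far below 1 when the Frobenius norm is much larger than the spectral norm. A minimal NumPy sketch, with frobenius_normalize as a hypothetical name:

```python
import numpy as np

def frobenius_normalize(G: np.ndarray, epsilon: float = 1e-7) -> np.ndarray:
    """X0 = G / (||G||_F + eps), with the norm taken over the last two axes."""
    return G / (np.linalg.norm(G, axis=(-2, -1), keepdims=True) + epsilon)

G = np.random.default_rng(0).standard_normal((32, 64))
X0 = frobenius_normalize(G)
print(np.linalg.norm(X0, 2) < 1.0)  # True, since ||X0||_2 <= ||X0||_F < 1
```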

epsilon : float, default = 1e-7

Small constant added to denominators (the Frobenius norm, or the AOL row sums) for numerical stability.

dtype : torch.dtype, default = torch.bfloat16

Working dtype for all computations. Supports BF16 for performance or FP32 for higher precision.


Returns

X : torch.Tensor

Approximate polar factor of G, with the same shape and dtype. If the input is tall (m > n), the routine transposes internally to operate on a wide matrix, then transposes back.

The output satisfies approximately:

  • XᵀX = I for m ≥ n
  • XXᵀ = I for m ≤ n

and spans the same column/row space as G.
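The transpose handling can be sketched as a small wrapper. run_on_wide is a hypothetical helper, not this repository's API. Assuming the Gram matrix is formed as X Xᵀ as in ns_line_1, a likely motivation is that the wide orientation keeps A and B at min(m, n) × min(m, n).

```python
import numpy as np

def run_on_wide(fn, G: np.ndarray) -> np.ndarray:
    """Apply fn to the wide orientation of G, then restore G's orientation."""
    tall = G.shape[-2] > G.shape[-1]
    X = np.swapaxes(G, -1, -2) if tall else G  # shape (..., min(m,n), max(m,n))
    Y = fn(X)
    return np.swapaxes(Y, -1, -2) if tall else Y

# Demo: record the Gram-matrix shape seen inside fn for a tall 100 x 8 input.
seen = []
def record_gram(X):
    seen.append((X @ np.swapaxes(X, -1, -2)).shape)
    return X  # identity, so the output should equal the input

G = np.random.default_rng(0).standard_normal((100, 8))
out = run_on_wide(record_gram, G)
print(seen, out.shape)  # [(8, 8)] (100, 8)
```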


Implementation Details

  • Benefits from torch.compile() for graph-level optimization (the Usage example wraps the function explicitly).

  • Coefficients (a, b, c) are stored in ns_consts; ns_consts[-iter:] selects the last iter triples as the schedule.

  • When preconditioning is enabled, the first triple is consumed by the fused AOL-NS step.

  • Three work buffers are reused to avoid allocation overhead:

    • A for Gram matrices
    • B for polynomial outputs
    • C for updated iterates
  • The design mirrors the Turbo-Muon formulation, matching the polar accuracy of optimized 5-step Muon at lower runtime.
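Putting the pieces together, a slow NumPy reference of the whole routine might look like the sketch below. It rests on several assumptions: newton_schulz_ref is a hypothetical name; it computes in float64 rather than BF16; it allocates new arrays instead of reusing the A/B/C buffers; and NS_CONSTS is a placeholder that repeats the widely used Muon quintic coefficients (3.4445, -4.7750, 2.0315) five times. It is not the ns_consts table shipped with this kernel, so its accuracy will differ.

```python
import numpy as np

# Placeholder schedule: the classic Muon quintic coefficients repeated five times.
# This is NOT the ns_consts table used by this repository.
NS_CONSTS = [(3.4445, -4.7750, 2.0315)] * 5

def newton_schulz_ref(G, iter=4, precondition=True, epsilon=1e-7):
    """Slow float64 sketch of the documented control flow (hypothetical helper)."""
    X = np.asarray(G, dtype=np.float64)
    tall = X.shape[-2] > X.shape[-1]
    if tall:                                   # work on the wide orientation
        X = np.swapaxes(X, -1, -2)
    if not precondition:                       # Frobenius normalization instead of AOL
        X = X / (np.linalg.norm(X, axis=(-2, -1), keepdims=True) + epsilon)
    for k, (a, b, c) in enumerate(NS_CONSTS[-iter:]):
        A = X @ np.swapaxes(X, -1, -2)         # ns_line_1: Gram matrix
        if precondition and k == 0:
            # AOL fused into the first iteration: rescale rows, reuse A
            s = 1.0 / np.sqrt(np.abs(A).sum(axis=-1) + epsilon)
            X = X * s[..., None]
            A = s[..., :, None] * A * s[..., None, :]
        B = b * A + c * (A @ A)                # ns_line_2
        X = a * X + B @ X                      # ns_line_3
    return np.swapaxes(X, -1, -2) if tall else X

G = np.random.default_rng(0).standard_normal((32, 64))
X = newton_schulz_ref(G, iter=4, precondition=True)
print(np.linalg.svd(X, compute_uv=False).round(2))  # close to, but not exactly, 1
```

Slicing NS_CONSTS[-iter:] keeps the last iter triples, so shorter schedules reuse the tail of the table; with precondition=True the first selected triple is applied to the AOL-rescaled Gram matrix, mirroring the fused step described above.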


Performance Notes

  • Designed for GPUs; FP16 and BF16 are supported.
  • Preconditioning is intended for fast approximate orthogonalization of large matrices.
  • torch.compile is supported and recommended; the kernel is not compiled by default.
  • An autograd-compatible version is available as newton_schulz_autograd. This feature is experimental.

Citation

Boissin, Massena et al., 2025.
Turbo-Muon: Accelerating Orthogonality-Based Optimization with Pre-Conditioning.

License

MIT License.

Credits for ns_line_1 and ns_line_2: https://github.com/microsoft/dion which is licensed under the MIT License.
