API Reference

Complete auto-generated API reference for all Triton-Augment operations.


Transform Classes

Transform classes provide stateful, random augmentations similar to torchvision. Recommended for training pipelines.

triton_augment.TritonFusedAugment

Bases: Module

Fused augmentation: All operations in ONE kernel.

This transform combines all augmentations in a single GPU kernel launch: affine (rotation/translation/scale/shear), crop, flip, color jitter, grayscale, and normalize, applied in that order.

Performance: up to 14x faster than torchvision.transforms.Compose!

Parameters:

Name Type Description Default
crop_size int | tuple[int, int] | None

Desired output size (int or tuple). If int, output is square (crop_size, crop_size). If None, output size equals input size (no cropping). Default: None.

None
horizontal_flip_p float

Probability of horizontal flip (default: 0.0, no flip)

0.0
degrees float | tuple[float, float]

Rotation degrees range. If float, range is (-degrees, +degrees). If tuple, range is (degrees[0], degrees[1]). Default: 0 (no rotation).

0
translate tuple[float, float] | None

Translation range as fraction of image size (tx, ty). E.g., (0.1, 0.1) means translate up to 10% of width/height. Default: None (no translation).

None
scale tuple[float, float] | None

Scale range (min, max). E.g., (0.8, 1.2) scales between 80% and 120%. Default: None (no scaling).

None
shear float | tuple[float, float] | None

Shear range in degrees. If float, range is (-shear, +shear) for x-axis. If tuple of 2, (shear[0], shear[1]) for x-axis. If tuple of 4, (shear[0], shear[1]) for x-axis and (shear[2], shear[3]) for y-axis. Default: None (no shearing).

None
interpolation

Interpolation mode for affine (InterpolationMode.NEAREST or BILINEAR). Only used in affine mode. Default: NEAREST.

NEAREST
fill float

Fill value for out-of-bounds pixels in affine mode. Default: 0.0.

0.0
brightness float | tuple[float, float]

How much to jitter brightness. If float, chosen uniformly from [max(0, 1-brightness), 1+brightness]. If tuple, chosen uniformly from [brightness[0], brightness[1]].

0
contrast float | tuple[float, float]

How much to jitter contrast (same format as brightness)

0
saturation float | tuple[float, float]

How much to jitter saturation (same format as brightness)

0
grayscale_p float

Probability of converting to grayscale (default: 0.0, no grayscale)

0.0
mean Optional[tuple[float, float, float]]

Sequence of means for R, G, B channels. If None, normalization is skipped.

None
std Optional[tuple[float, float, float]]

Sequence of stds for R, G, B channels. If None, normalization is skipped.

None
same_on_batch bool

If True, all images in batch (N dimension) share the same random parameters. If False (default), each image gets different random parameters.

False
same_on_frame bool

If True, all frames in a video (T dimension) share the same random parameters. If False, each frame gets different random parameters. Only applies to 5D input [N, T, C, H, W]. Default: True (consistent augmentation across frames).

True
Example
# Simple mode (crop + flip + color)
transform = ta.TritonFusedAugment(
    crop_size=112,
    horizontal_flip_p=0.5,
    brightness=0.2,
    contrast=0.2,
    saturation=0.2,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225)
)

# Affine mode (rotation + scale + color)
transform = ta.TritonFusedAugment(
    crop_size=224,
    degrees=30,           # Enables affine mode
    scale=(0.8, 1.2),
    horizontal_flip_p=0.5,
    brightness=0.2,
    interpolation=InterpolationMode.BILINEAR
)

img = torch.rand(4, 3, 224, 224, device='cuda')
result = transform(img)  # Single kernel launch!
Note
  • Uses FAST contrast (centered scaling), not torchvision's blend-with-mean
  • Input must be (C, H, W), (N, C, H, W), or (N, T, C, H, W) float tensor on CUDA in [0, 1] range
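The interaction of `same_on_batch` and `same_on_frame` can be summarized by counting how many independent random parameter sets are drawn. The sketch below is illustrative only (a hypothetical helper, not part of the triton_augment API):

```python
# Illustrative sketch: how many independent random parameter sets are drawn
# for a 5D video batch [N, T, C, H, W], given the flags described above.
def num_param_sets(n: int, t: int, same_on_batch: bool, same_on_frame: bool) -> int:
    batch_sets = 1 if same_on_batch else n   # one set shared across N, or one per video
    frame_sets = 1 if same_on_frame else t   # one set shared across T, or one per frame
    return batch_sets * frame_sets

# Defaults (same_on_batch=False, same_on_frame=True): one set per video clip.
print(num_param_sets(4, 8, False, True))   # 4
print(num_param_sets(4, 8, False, False))  # 32
print(num_param_sets(4, 8, True, True))    # 1
```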

Functions

forward

forward(image: Tensor) -> torch.Tensor

Apply all augmentations in a single fused kernel.

Parameters:

Name Type Description Default
image Tensor

Input tensor of shape (C, H, W), (N, C, H, W), or (N, T, C, H, W). Can be on CPU or CUDA (it will be moved to CUDA automatically).

required

Returns:

Type Description
Tensor

Augmented tensor of same shape and device as input


triton_augment.TritonColorJitterNormalize

Bases: Module

Combined color jitter, random grayscale, and normalization in a single fused operation.

This class combines TritonColorJitter, TritonRandomGrayscale, and TritonNormalize into a single operation that uses a fused kernel for maximum performance. This is the recommended way to apply color augmentations and normalization.

Parameters:

Name Type Description Default
brightness Optional[Union[float, Sequence[float]]]

How much to jitter brightness (same as TritonColorJitter)

None
contrast Optional[Union[float, Sequence[float]]]

How much to jitter contrast (same as TritonColorJitter)

None
saturation Optional[Union[float, Sequence[float]]]

How much to jitter saturation (same as TritonColorJitter)

None
grayscale_p float

Probability of converting to grayscale (default: 0.0)

0.0
mean Optional[Tuple[float, float, float]]

Sequence of means for normalization (R, G, B). If None, normalization is skipped.

None
std Optional[Tuple[float, float, float]]

Sequence of standard deviations for normalization (R, G, B). If None, normalization is skipped.

None
same_on_batch bool

If True, all images in batch share the same random parameters. If False (default), each image in batch gets different random parameters.

False
Example
# Full augmentation pipeline in one transform (per-image randomness)
transform = TritonColorJitterNormalize(
    brightness=0.2,  # Range: [0.8, 1.2]
    contrast=0.2,    # Range: [0.8, 1.2]
    saturation=0.2,  # Range: [0.8, 1.2]
    grayscale_p=0.1,  # 10% chance of grayscale (per-image)
    mean=(0.485, 0.456, 0.406),  # ImageNet normalization (optional)
    std=(0.229, 0.224, 0.225),    # ImageNet normalization (optional)
    same_on_batch=False
)
img = torch.rand(4, 3, 224, 224, device='cuda')
augmented = transform(img)  # Each image gets different augmentation

# Without normalization (mean=None, std=None by default)
transform_no_norm = TritonColorJitterNormalize(
    brightness=0.2, contrast=0.2, saturation=0.2
)

Functions

forward

forward(img: Tensor) -> torch.Tensor

Apply random color jitter and normalization in a single fused operation.

Parameters:

Name Type Description Default
img Tensor

Input tensor of shape (C, H, W), (N, C, H, W), or (N, T, C, H, W). Can be on CPU or CUDA (it will be moved to CUDA automatically).

required

Returns:

Type Description
Tensor

Augmented and normalized tensor of same shape and device as input


triton_augment.TritonRandomCropFlip

Bases: Module

Fused random crop + random horizontal flip.

This class combines random crop and random horizontal flip in a SINGLE kernel launch, eliminating intermediate memory transfers.

Performance: ~1.5-2x faster than applying TritonRandomCrop + TritonRandomHorizontalFlip sequentially.

Parameters:

Name Type Description Default
size Union[int, Sequence[int]]

Desired output size (height, width) or int for square crop

required
horizontal_flip_p float

Probability of horizontal flip (default: 0.5)

0.5
same_on_batch bool

If True, all images in batch (N dimension) share the same random parameters. If False (default), each image gets different random parameters.

False
same_on_frame bool

If True, all frames in a video (T dimension) share the same random parameters. If False, each frame gets different random parameters. Only applies to 5D input [N, T, C, H, W]. Default: True.

True
Example
# Fused version (FAST - single kernel, per-image randomness)
transform_fused = TritonRandomCropFlip(112, horizontal_flip_p=0.5, same_on_batch=False)
img = torch.rand(4, 3, 224, 224, device='cuda')
result = transform_fused(img)  # Each image gets different crop & flip

# Equivalent sequential version (SLOWER - 2 kernels)
transform_seq = nn.Sequential(
    TritonRandomCrop(112, same_on_batch=False),
    TritonRandomHorizontalFlip(p=0.5, same_on_batch=False)
)
result_seq = transform_seq(img)
Note

The fused version uses compile-time branching (tl.constexpr), so there's zero overhead when flip is not triggered.

Functions

forward

forward(image: Tensor) -> torch.Tensor

Parameters:

Name Type Description Default
image Tensor

Input tensor of shape (C, H, W), (N, C, H, W), or (N, T, C, H, W)

required

Returns:

Type Description
Tensor

Randomly cropped (and possibly flipped) tensor; spatial dimensions equal the requested size, all other dimensions match the input


triton_augment.TritonColorJitter

Bases: Module

Randomly change the brightness, contrast, and saturation of an image.

This is a GPU-accelerated version of torchvision.transforms.v2.ColorJitter that uses a fused kernel for maximum performance.

IMPORTANT: Contrast uses FAST mode (centered scaling: (pixel - 0.5) * factor + 0.5), NOT torchvision's blend-with-mean approach. This is much faster and provides similar visual results.

If you need exact torchvision behavior, use the individual functional APIs:
  • F.adjust_brightness() (exact)
  • F.adjust_contrast() (torchvision-exact, slower)
  • F.adjust_saturation() (exact)

Parameters:

Name Type Description Default
brightness Optional[Union[float, Sequence[float]]]

How much to jitter brightness. brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] or the given [min, max]. Should be non-negative numbers.

None
contrast Optional[Union[float, Sequence[float]]]

How much to jitter contrast (uses FAST mode). contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] or the given [min, max]. Should be non-negative numbers.

None
saturation Optional[Union[float, Sequence[float]]]

How much to jitter saturation. saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] or the given [min, max]. Should be non-negative numbers.

None
same_on_batch bool

If True, all images in batch (N dimension) share the same random parameters. If False (default), each image gets different random parameters.

False
same_on_frame bool

If True, all frames in a video (T dimension) share the same random parameters. If False, each frame gets different random parameters. Only applies to 5D input [N, T, C, H, W]. Default: True (consistent augmentation across frames).

True
Example
# Basic usage with per-image randomness
transform = TritonColorJitter(
    brightness=0.2,  # Range: [0.8, 1.2]
    contrast=0.2,    # Range: [0.8, 1.2] (FAST contrast)
    saturation=0.2,  # Range: [0.8, 1.2]
    same_on_batch=False
)
img = torch.rand(4, 3, 224, 224, device='cuda')
augmented = transform(img)  # Each image gets different augmentation

# Custom ranges
transform = TritonColorJitter(
    brightness=(0.5, 1.5),  # Custom range
    contrast=(0.7, 1.3),     # Custom range (FAST mode)
    saturation=(0.0, 2.0)    # Custom range
)
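How a scalar jitter argument maps to a sampling range, per the parameter descriptions above, can be sketched in pure Python (`jitter_range` is a hypothetical helper, not the library's API):

```python
import random

# Sketch: a scalar v maps to the range [max(0, 1 - v), 1 + v];
# an explicit (min, max) pair is used as-is.
def jitter_range(value):
    if isinstance(value, (tuple, list)):
        return (float(value[0]), float(value[1]))
    return (max(0.0, 1.0 - value), 1.0 + value)

low, high = jitter_range(0.2)          # (0.8, 1.2)
factor = random.uniform(low, high)     # a per-image brightness/contrast/saturation factor
assert low <= factor <= high
print(jitter_range(0.2), jitter_range(1.5))  # (0.8, 1.2) (0.0, 2.5)
```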
Performance
  • Uses fused kernel for all operations in a single pass
  • Faster than sequential operations
  • For even more speed, combine with normalization using TritonColorJitterNormalize

Functions

forward

forward(img: Tensor) -> torch.Tensor

Apply random color jitter to the input image tensor.

Parameters:

Name Type Description Default
img Tensor

Input image tensor of shape (C, H, W), (N, C, H, W), or (N, T, C, H, W)

required

Returns:

Type Description
Tensor

Augmented image tensor of the same shape and dtype


triton_augment.TritonNormalize

Bases: Module

Normalize a tensor image with mean and standard deviation.

This is a GPU-accelerated version of torchvision.transforms.Normalize that uses a Triton kernel for improved performance.

Parameters:

Name Type Description Default
mean Tuple[float, float, float]

Sequence of means for each channel (R, G, B)

required
std Tuple[float, float, float]

Sequence of standard deviations for each channel (R, G, B)

required
Example
# ImageNet normalization
normalize = TritonNormalize(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225)
)
img = torch.rand(1, 3, 224, 224, device='cuda')
normalized = normalize(img)

Functions

forward

forward(img: Tensor) -> torch.Tensor

Normalize the input image tensor.

Parameters:

Name Type Description Default
img Tensor

Input tensor of shape (C, H, W), (N, C, H, W), or (N, T, C, H, W). Can be on CPU or CUDA (it will be moved to CUDA automatically).

required

Returns:

Type Description
Tensor

Normalized tensor of same shape and device as input


triton_augment.TritonRandomGrayscale

Bases: Module

Randomly convert image to grayscale with probability p.

Matches torchvision.transforms.v2.RandomGrayscale behavior with optional per-image randomness.

Parameters:

Name Type Description Default
p float

Probability of converting to grayscale (default: 0.1)

0.1
num_output_channels int

Number of output channels (1 or 3, default: 3). Usually 3 to maintain compatibility with RGB pipelines.

3
same_on_batch bool

If True, all images in batch (N dimension) make the same grayscale decision. If False (default), each image independently decides grayscale conversion.

False
same_on_frame bool

If True, all frames in a video (T dimension) make the same grayscale decision. If False, each frame independently decides. Only applies to 5D input [N, T, C, H, W]. Default: True.

True
Example
# Per-image randomness (each image independently converted)
transform = TritonRandomGrayscale(p=0.5, num_output_channels=3, same_on_batch=False)
img = torch.rand(4, 3, 224, 224, device='cuda')
result = transform(img)  # Each image has 50% chance of being grayscale

# Batch-wide (all images converted or none)
transform = TritonRandomGrayscale(p=0.5, num_output_channels=3, same_on_batch=True)
result = transform(img)  # Either all 4 images are grayscale or none are

Functions

forward

forward(image: Tensor) -> torch.Tensor

Parameters:

Name Type Description Default
image Tensor

Input tensor of shape (C, H, W), (N, 3, H, W), or (N, T, 3, H, W)

required

Returns:

Type Description
Tensor

Image tensor, either original or grayscale based on probability


triton_augment.TritonGrayscale

Bases: Module

Convert image to grayscale.

Matches torchvision.transforms.v2.Grayscale behavior. Uses weights: 0.2989*R + 0.587*G + 0.114*B

Parameters:

Name Type Description Default
num_output_channels int

Number of output channels (1 or 3). If 1, output is single-channel grayscale. If 3, grayscale is replicated to 3 channels.

1
Example
transform = TritonGrayscale(num_output_channels=3)
img = torch.rand(1, 3, 224, 224, device='cuda')
gray = transform(img)  # Shape: (1, 3, 224, 224), all channels identical
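The BT.601 luma weights quoted above can be checked on a single RGB pixel in pure Python; the library applies the same formula per pixel on the GPU:

```python
# Grayscale conversion of one RGB pixel with the weights from the docstring.
def to_gray(r: float, g: float, b: float) -> float:
    return 0.2989 * r + 0.587 * g + 0.114 * b

# The weights sum to 0.9999, so white stays (almost exactly) white.
assert abs(to_gray(1.0, 1.0, 1.0) - 0.9999) < 1e-12
print(round(to_gray(1.0, 0.0, 0.0), 4))  # 0.2989 -- pure red maps to a dark gray
```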

Functions

forward

forward(image: Tensor) -> torch.Tensor

Parameters:

Name Type Description Default
image Tensor

Input tensor of shape (C, H, W), (N, 3, H, W), or (N, T, 3, H, W)

required

Returns:

Type Description
Tensor

Grayscale tensor of shape matching input structure


triton_augment.TritonRandomCrop

Bases: Module

Crop a random portion of the image.

Matches torchvision.transforms.v2.RandomCrop behavior (simplified MVP version).

Parameters:

Name Type Description Default
size Union[int, Sequence[int]]

Desired output size (height, width) or int for square crop

required
same_on_batch bool

If True, all images in batch (N dimension) crop at the same position. If False (default), each image gets different random crop position.

False
same_on_frame bool

If True, all frames in a video (T dimension) crop at the same position. If False, each frame gets different random crop. Only applies to 5D input [N, T, C, H, W]. Default: True.

True
Example

transform = TritonRandomCrop(112)
img = torch.rand(4, 3, 224, 224, device='cuda')
cropped = transform(img)
cropped.shape  # torch.Size([4, 3, 112, 112])

Note

For MVP, padding is not supported. Image must be larger than crop size. Future versions will support padding, pad_if_needed, fill, padding_mode.
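The valid offset range implied by the no-padding constraint can be sketched as follows (`sample_crop` is a hypothetical helper, not the library's API):

```python
import random

# Sketch: for input (H, W) and crop (h, w), top is drawn from [0, H - h]
# and left from [0, W - w]; larger crops are rejected (no padding in the MVP).
def sample_crop(h_in: int, w_in: int, h_out: int, w_out: int):
    if h_out > h_in or w_out > w_in:
        raise ValueError("crop size must not exceed input size (padding not supported)")
    top = random.randint(0, h_in - h_out)
    left = random.randint(0, w_in - w_out)
    return top, left

top, left = sample_crop(224, 224, 112, 112)
assert 0 <= top <= 112 and 0 <= left <= 112
```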

Functions

forward

forward(image: Tensor) -> torch.Tensor

Parameters:

Name Type Description Default
image Tensor

Input tensor of shape (C, H, W), (N, C, H, W), or (N, T, C, H, W)

required

Returns:

Type Description
Tensor

Randomly cropped tensor; spatial dimensions equal the requested size, all other dimensions match the input


triton_augment.TritonCenterCrop

Bases: Module

Crop the center of the image.

Matches torchvision.transforms.v2.CenterCrop behavior.

Parameters:

Name Type Description Default
size Union[int, Sequence[int]]

Desired output size (height, width) or int for square crop

required
Example

transform = TritonCenterCrop(112)
img = torch.rand(4, 3, 224, 224, device='cuda')
cropped = transform(img)
cropped.shape  # torch.Size([4, 3, 112, 112])

Functions

forward

forward(image: Tensor) -> torch.Tensor

Parameters:

Name Type Description Default
image Tensor

Input tensor of shape (C, H, W), (N, C, H, W), or (N, T, C, H, W)

required

Returns:

Type Description
Tensor

Center-cropped tensor; spatial dimensions equal the requested size, all other dimensions match the input


triton_augment.TritonRandomAffine

Bases: Module

Random affine transformation of the image keeping center invariant.

GPU-accelerated implementation using Triton kernels. Matches the API of torchvision.transforms.v2.RandomAffine.

Supports
  • 3D input: (C, H, W) - single image
  • 4D input: (N, C, H, W) - batch of images
  • 5D input: (N, T, C, H, W) - batch of videos

Parameters:

Name Type Description Default
degrees Union[float, Sequence[float]]

Range of degrees to select from. If degrees is a number instead of sequence like (min, max), the range of degrees will be (-degrees, +degrees). Set to 0 to deactivate rotations.

required
translate Optional[Tuple[float, float]]

Tuple of maximum absolute fraction for horizontal and vertical translations. For example translate=(a, b), then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.

None
scale Optional[Tuple[float, float]]

Scaling factor interval, e.g (a, b), then scale is randomly sampled from the range a <= scale <= b. Will keep original scale by default.

None
shear Optional[Union[float, Sequence[float]]]

Range of degrees to select from. If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) will be applied. Else if shear is a tuple of 2 values, a x-axis shear in (shear[0], shear[1]) will be applied. Else if shear is a tuple of 4 values, a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. Default: None.

None
interpolation

Interpolation mode for sampling. Either:
  • InterpolationMode.NEAREST (default): nearest neighbor, faster.
  • InterpolationMode.BILINEAR: bilinear interpolation, smoother.

NEAREST
fill float

Constant fill value for areas outside the transformed image. Default: 0.

0.0
center Optional[Tuple[int, int]]

Optional center of rotation (x, y) in pixel coordinates. Origin is the upper left corner. Default is the center of the image.

None
same_on_batch bool

If True, all images in batch share the same random parameters. Default: False.

False
same_on_frame bool

If True, all frames in a video (5D input) share the same random parameters. Default: True.

True
Note

For nearest neighbor interpolation, there may be minor differences compared to torchvision at exact pixel boundaries due to floating-point rounding. Bilinear interpolation does not have this limitation.

Example
transform = TritonRandomAffine(
    degrees=15,
    translate=(0.1, 0.1),
    scale=(0.8, 1.2),
    shear=10,
    interpolation=InterpolationMode.BILINEAR
)

# Apply to batch of images
img = torch.rand(4, 3, 224, 224, device='cuda')
result = transform(img)

# Apply to video with same transform per frame
video = torch.rand(2, 8, 3, 112, 112, device='cuda')
result = transform(video)
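The geometry behind the `degrees`/`scale`/`translate`/`center` parameters can be sketched as a 2x3 affine matrix. This is a simplified illustration (shear omitted, and not necessarily the library's exact matrix convention):

```python
import math

# Sketch: compose rotation + uniform scale + translation about a center (cx, cy).
# A point p maps to R @ (p - c) + c + t, flattened here into a 2x3 matrix.
def affine_matrix(angle_deg, scale, tx, ty, cx, cy):
    a = math.radians(angle_deg)
    cos_a, sin_a = scale * math.cos(a), scale * math.sin(a)
    return [
        [cos_a, -sin_a, cx - cos_a * cx + sin_a * cy + tx],
        [sin_a,  cos_a, cy - sin_a * cx - cos_a * cy + ty],
    ]

# Identity parameters (angle 0, scale 1, no translation) yield a no-op matrix.
m = affine_matrix(0.0, 1.0, 0.0, 0.0, 112.0, 112.0)
assert m[0][:2] == [1.0, 0.0] and m[1][:2] == [0.0, 1.0]
assert m[0][2] == 0.0 and m[1][2] == 0.0
```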

Functions

forward

forward(img: Tensor) -> torch.Tensor

Apply affine transformation.


triton_augment.TritonRandomRotation

Bases: TritonRandomAffine

Random rotation of the image.

GPU-accelerated implementation using Triton kernels. Matches the API of torchvision.transforms.v2.RandomRotation.

Supports
  • 3D input: (C, H, W) - single image
  • 4D input: (N, C, H, W) - batch of images
  • 5D input: (N, T, C, H, W) - batch of videos

Parameters:

Name Type Description Default
degrees Union[float, Sequence[float]]

Range of degrees to select from. If degrees is a number instead of sequence like (min, max), the range of degrees will be (-degrees, +degrees).

required
interpolation

Interpolation mode for sampling. Either:
  • InterpolationMode.NEAREST (default): nearest neighbor, faster.
  • InterpolationMode.BILINEAR: bilinear interpolation, smoother.

NEAREST
expand bool

If True, expands the output to hold the entire rotated image. Currently not supported (raises NotImplementedError).

False
center Optional[Tuple[int, int]]

Optional center of rotation (x, y) in pixel coordinates. Origin is the upper left corner. Default is the center of the image.

None
fill float

Constant fill value for areas outside the rotated image. Default: 0.

0.0
same_on_batch bool

If True, all images in batch share the same random rotation angle. Default: False.

False
same_on_frame bool

If True, all frames in a video (5D input) share the same random rotation angle. Default: True.

True
Note

For nearest neighbor interpolation, there may be minor differences compared to torchvision at exact pixel boundaries due to floating-point rounding. Bilinear interpolation does not have this limitation.

Example
transform = TritonRandomRotation(
    degrees=30,
    interpolation=InterpolationMode.BILINEAR,
    fill=0.5
)

# Apply to batch of images
img = torch.rand(4, 3, 224, 224, device='cuda')
result = transform(img)

triton_augment.TritonRandomHorizontalFlip

Bases: Module

Horizontally flip the image randomly with probability p.

Matches torchvision.transforms.v2.RandomHorizontalFlip behavior.

Parameters:

Name Type Description Default
p float

Probability of flipping (default: 0.5)

0.5
same_on_batch bool

If True, all images in batch (N dimension) share the same flip decision. If False (default), each image gets different random decision.

False
same_on_frame bool

If True, all frames in a video (T dimension) share the same flip decision. If False, each frame gets different random decision. Only applies to 5D input [N, T, C, H, W]. Default: True.

True
Example
transform = TritonRandomHorizontalFlip(p=0.5)
img = torch.rand(4, 3, 224, 224, device='cuda')
flipped = transform(img)  # Each image has 50% chance of being flipped

Functions

forward

forward(image: Tensor) -> torch.Tensor

Parameters:

Name Type Description Default
image Tensor

Input tensor of shape (C, H, W), (N, C, H, W), or (N, T, C, H, W)

required

Returns:

Type Description
Tensor

Image tensor, either original or horizontally flipped


Functional API

Low-level functional interface for fine-grained control with fixed parameters. Use when you need deterministic operations.

triton_augment.functional.fused_augment

fused_augment(image: Tensor, top: int | Tensor, left: int | Tensor, height: int, width: int, flip_horizontal: bool | Tensor = False, angle: float | Tensor = 0.0, translate: list[float] | Tensor = [0.0, 0.0], scale: float | Tensor = 1.0, shear: list[float] | Tensor = [0.0, 0.0], interpolation: str = 'nearest', fill: float = 0.0, center: list[float] | None = None, brightness_factor: float | Tensor = 1.0, contrast_factor: float | Tensor = 1.0, saturation_factor: float | Tensor = 1.0, grayscale: bool | Tensor = False, mean: tuple[float, float, float] | None = None, std: tuple[float, float, float] | None = None) -> torch.Tensor

Fused augmentation: ALL operations in ONE kernel.

Combines geometric (crop + flip + affine) and pixel (color + normalize) operations in a single GPU kernel, providing maximum performance.

Automatically selects the most efficient kernel mode:
  1. Simple mode (crop + flip only): used when all affine parameters are identity. Fastest path; uses integer indexing.
  2. Affine mode (general): used when any affine parameter is non-identity. Uses matrix multiplication and interpolation.
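The selection rule can be sketched as an identity check over the affine parameters (`needs_affine_mode` is a hypothetical helper illustrating the rule, not the library's internal function):

```python
# Sketch: the affine path is only needed when some affine parameter is non-identity.
def needs_affine_mode(angle=0.0, translate=(0.0, 0.0), scale=1.0, shear=(0.0, 0.0)):
    return (angle != 0.0
            or any(t != 0.0 for t in translate)
            or scale != 1.0
            or any(s != 0.0 for s in shear))

assert not needs_affine_mode()            # crop + flip only -> simple mode
assert needs_affine_mode(angle=30.0)      # rotation -> affine mode
assert needs_affine_mode(scale=1.2)       # scaling -> affine mode
```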

Parameters:

Name Type Description Default
image Tensor

Input image tensor of shape (N, C, H, W) on CUDA

required
top int | Tensor

Crop top offset (int or int32 tensor of shape (N,) for per-image)

required
left int | Tensor

Crop left offset (int or int32 tensor of shape (N,) for per-image)

required
height int

Crop height

required
width int

Crop width

required
flip_horizontal bool | Tensor

Whether to flip horizontally (bool or uint8 tensor of shape (N,) for per-image, default: False)

False
angle float | Tensor

Rotation angle in degrees (default: 0.0)

0.0
translate list[float] | Tensor

Translation [dx, dy] (default: [0.0, 0.0])

[0.0, 0.0]
scale float | Tensor

Scale factor (default: 1.0)

1.0
shear list[float] | Tensor

Shear angles [sx, sy] (default: [0.0, 0.0])

[0.0, 0.0]
interpolation str

"nearest" or "bilinear" (default: "nearest")

'nearest'
fill float

Fill value for out-of-bounds pixels (default: 0.0)

0.0
center list[float] | None

Center of rotation (default: image center)

None
brightness_factor float | Tensor

Brightness multiplier (float or tensor of shape (N,) for per-image, 1.0 = no change)

1.0
contrast_factor float | Tensor

Contrast multiplier (float or tensor of shape (N,) for per-image, 1.0 = no change) [FAST mode]

1.0
saturation_factor float | Tensor

Saturation multiplier (float or tensor of shape (N,) for per-image, 1.0 = no change)

1.0
grayscale bool | Tensor

Whether to convert to grayscale (bool or uint8 tensor of shape (N,) for per-image, default: False)

False
mean tuple[float, float, float] | None

Normalization mean parameters (None = skip normalization)

None
std tuple[float, float, float] | None

Normalization std parameters (None = skip normalization)

None

Returns:

Type Description
Tensor

Transformed tensor of shape (N, C, height, width)

Example
img = torch.rand(4, 3, 224, 224, device='cuda')

# Simple mode (Crop + Flip + Color)
result = fused_augment(
    img, top=0, left=0, height=224, width=224,
    flip_horizontal=True,
    brightness_factor=1.2
)

# Affine mode (Rotation + Crop + Color)
result = fused_augment(
    img, top=0, left=0, height=224, width=224,
    angle=30.0,  # Triggers affine mode
    brightness_factor=1.2
)

triton_augment.functional.adjust_brightness

adjust_brightness(image: Tensor, brightness_factor: float) -> torch.Tensor

Adjust brightness of an image (MULTIPLICATIVE operation).

Matches torchvision.transforms.v2.functional.adjust_brightness exactly. Reference: torchvision/transforms/v2/functional/_color.py line 114-125

Formula: output = input * brightness_factor

Parameters:

Name Type Description Default
image Tensor

Input image tensor of shape (N, C, H, W) on CUDA

required
brightness_factor float

How much to adjust the brightness. Must be non-negative. 0 gives a black image, 1 gives the original image, and 2 doubles the brightness.

required

Returns:

Type Description
Tensor

Brightness-adjusted tensor of the same shape and dtype

Raises:

Type Description
ValueError

If brightness_factor is negative

Example
img = torch.rand(1, 3, 224, 224, device='cuda')
bright_img = adjust_brightness(img, brightness_factor=1.2)  # 20% brighter
dark_img = adjust_brightness(img, brightness_factor=0.8)   # 20% darker

triton_augment.functional.adjust_contrast

adjust_contrast(image: Tensor, contrast_factor: float) -> torch.Tensor

Adjust contrast of an image.

Matches torchvision.transforms.v2.functional.adjust_contrast exactly. Reference: torchvision/transforms/v2/functional/_color.py line 190-206

output = blend(image, grayscale_mean, contrast_factor)
       = image * contrast_factor + grayscale_mean * (1 - contrast_factor)

Parameters:

Name Type Description Default
image Tensor

Input image tensor of shape (N, C, H, W) on CUDA

required
contrast_factor float

How much to adjust the contrast. Must be non-negative. 0 gives a gray image, 1 gives the original image, values > 1 increase contrast.

required

Returns:

Type Description
Tensor

Contrast-adjusted tensor of the same shape and dtype

Raises:

Type Description
ValueError

If contrast_factor is negative

Example
img = torch.rand(1, 3, 224, 224, device='cuda')
high_contrast = adjust_contrast(img, contrast_factor=1.5)
low_contrast = adjust_contrast(img, contrast_factor=0.5)
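The blend-with-mean formula above can be traced on a tiny two-pixel grayscale "image" in pure Python; the library computes the same thing per pixel on the GPU:

```python
# Torchvision-style contrast: blend each pixel with the grayscale mean.
pixels = [0.2, 0.8]
gray_mean = sum(pixels) / len(pixels)  # mean of the grayscale image (here 0.5)
factor = 1.5
out = [p * factor + gray_mean * (1 - factor) for p in pixels]
print(out)  # approximately [0.05, 0.95]: values pushed away from the mean
assert out[0] < pixels[0] < pixels[1] < out[1]
```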

triton_augment.functional.adjust_contrast_fast

adjust_contrast_fast(image: Tensor, contrast_factor: float) -> torch.Tensor

Adjust contrast of an image using FAST centered scaling.

This is faster than adjust_contrast() because it doesn't require computing the grayscale mean. Uses formula: output = (input - 0.5) * contrast_factor + 0.5

NOTE: This is NOT equivalent to torchvision's adjust_contrast, but provides similar perceptual results and is fully fusible with other operations.

Use this when:
  • You want maximum performance with fusion
  • Exact torchvision reproduction is not critical

Use adjust_contrast() when:
  • You need exact torchvision compatibility
  • Reproducibility with torchvision is required

Parameters:

Name Type Description Default
image Tensor

Input image tensor of shape (N, C, H, W) on CUDA

required
contrast_factor float

How much to adjust the contrast. Must be non-negative. 0.5 decreases contrast, 1.0 gives original, >1.0 increases.

required

Returns:

Type Description
Tensor

Contrast-adjusted tensor of the same shape and dtype

Raises:

Type Description
ValueError

If contrast_factor is negative

Example
img = torch.rand(1, 3, 224, 224, device='cuda')
high_contrast = adjust_contrast_fast(img, contrast_factor=1.5)
# Use in fused operation for maximum speed
result = fused_color_normalize(img, contrast_factor=1.5, ...)
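To make the "NOT equivalent" note concrete, the two formulas can be compared on the same pixels in pure Python. They agree only when the image mean happens to be 0.5:

```python
# FAST centered scaling vs. torchvision-exact blend-with-mean, side by side.
pixels = [0.1, 0.3]  # a dark image: mean 0.2, far from the 0.5 center
factor = 1.5
fast = [(p - 0.5) * factor + 0.5 for p in pixels]            # centers on 0.5
mean = sum(pixels) / len(pixels)
exact = [p * factor + mean * (1 - factor) for p in pixels]   # centers on the true mean
print(fast)   # approximately [-0.1, 0.2]: dark pixels pushed darker, even below 0
print(exact)  # approximately [0.05, 0.35]: stretched around the true mean
assert fast != exact
```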

triton_augment.functional.adjust_saturation

adjust_saturation(image: Tensor, saturation_factor: float) -> torch.Tensor

Adjust color saturation of an image.

Matches torchvision.transforms.v2.functional.adjust_saturation exactly.

Formula
output = blend(image, grayscale, saturation_factor)
       = image * saturation_factor + grayscale * (1 - saturation_factor)

Parameters:

Name Type Description Default
image Tensor

Input image tensor of shape (N, C, H, W) on CUDA

required
saturation_factor float

How much to adjust the saturation. Must be non-negative. 0 will give a grayscale image, 1 will give the original image, values > 1 increase saturation.

required

Returns:

Type Description
Tensor

Saturation-adjusted tensor of the same shape and dtype

Raises:

Type Description
ValueError

If saturation_factor is negative

Example
img = torch.rand(1, 3, 224, 224, device='cuda')
grayscale = adjust_saturation(img, saturation_factor=0.0)
saturated = adjust_saturation(img, saturation_factor=2.0)
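The saturation blend above can be traced on a single RGB pixel: each channel is blended with that pixel's grayscale value (BT.601 weights), in pure Python:

```python
# Saturation adjustment of one RGB pixel via blending with its grayscale value.
r, g, b = 0.9, 0.3, 0.1
gray = 0.2989 * r + 0.587 * g + 0.114 * b

factor = 0.0  # fully desaturated
out = [c * factor + gray * (1 - factor) for c in (r, g, b)]
assert out == [gray, gray, gray]  # factor 0 collapses the pixel to grayscale

factor = 2.0  # oversaturated: channels pushed away from the grayscale value
out2 = [c * factor + gray * (1 - factor) for c in (r, g, b)]
assert out2[0] > r and out2[1] < g
```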

triton_augment.functional.normalize

normalize(image: Tensor, mean: tuple[float, float, float], std: tuple[float, float, float]) -> torch.Tensor

Normalize a tensor image with mean and standard deviation.

This function normalizes each channel:

output[c] = (input[c] - mean[c]) / std[c]

Parameters:

Name Type Description Default
image Tensor

Input image tensor of shape (N, C, H, W) on CUDA

required
mean tuple[float, float, float]

Tuple of mean values for each channel (R, G, B)

required
std tuple[float, float, float]

Tuple of standard deviation values for each channel (R, G, B)

required

Returns:

Type Description
Tensor

Normalized tensor of the same shape and dtype

Example
img = torch.rand(1, 3, 224, 224, device='cuda')
normalized = normalize(img,
                      mean=(0.485, 0.456, 0.406),
                      std=(0.229, 0.224, 0.225))
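The per-channel formula above can be checked on a single RGB pixel with the ImageNet constants, in pure Python:

```python
# Per-channel normalization: output[c] = (input[c] - mean[c]) / std[c].
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
pixel = (0.485, 0.456, 0.406)  # a pixel exactly at the per-channel means
out = [(p - m) / s for p, m, s in zip(pixel, mean, std)]
assert out == [0.0, 0.0, 0.0]  # mean-valued input maps to zero in every channel
```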

triton_augment.functional.rgb_to_grayscale

rgb_to_grayscale(image: Tensor, num_output_channels: int = 1, grayscale_mask: Tensor | None = None) -> torch.Tensor

Convert RGB image to grayscale with optional per-image masking.

Matches torchvision.transforms.v2.functional.rgb_to_grayscale exactly. Uses weights: 0.2989*R + 0.587*G + 0.114*B

Parameters:

Name Type Description Default
image Tensor

Input image tensor of shape (N, 3, H, W) on CUDA

required
num_output_channels int

Number of output channels (1 or 3). If 3, grayscale is replicated across channels.

1
grayscale_mask Tensor | None

Optional per-image mask of shape (N,) (uint8 or bool: 0=keep original, 1=convert to gray). If None, converts all images.

None

Returns:

Type Description
Tensor

Grayscale tensor of shape (N, num_output_channels, H, W)

Raises:

Type Description
ValueError

If num_output_channels not in {1, 3} or if input not RGB

Example
img = torch.rand(4, 3, 224, 224, device='cuda')
# Convert all images
gray = rgb_to_grayscale(img, num_output_channels=3)
# Convert only some images (per-image mask)
mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool, device='cuda')
gray = rgb_to_grayscale(img, num_output_channels=3, grayscale_mask=mask)
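The masking behavior can be sketched in plain PyTorch as a weighted sum followed by `torch.where`. A CPU sketch with the hypothetical name `grayscale_reference`, shown with 3 output channels so masked-out images keep their original RGB values:

```python
import torch

def grayscale_reference(image, grayscale_mask=None):
    w = torch.tensor([0.2989, 0.587, 0.114]).view(1, 3, 1, 1)
    gray = (image * w).sum(dim=1, keepdim=True).expand(-1, 3, -1, -1)
    if grayscale_mask is None:
        return gray                            # default: convert all images
    convert = grayscale_mask.view(-1, 1, 1, 1).to(torch.bool)
    # 1/True = convert to gray, 0/False = keep the original image
    return torch.where(convert, gray, image)

img = torch.rand(4, 3, 8, 8)
mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool)
out = grayscale_reference(img, mask)
```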

triton_augment.functional.crop

crop(image: Tensor, top: int | Tensor, left: int | Tensor, height: int, width: int) -> torch.Tensor

Crop a rectangular region from the input image, with optional per-image crop positions.

Matches torchvision.transforms.v2.functional.crop exactly when using scalar top/left. Reference: torchvision/transforms/v2/functional/_geometry.py line 1787

Parameters:

Name Type Description Default
image Tensor

Input image tensor of shape (N, C, H, W) on CUDA

required
top int | Tensor

Top pixel coordinate for cropping (int or int32 tensor of shape (N,) for per-image)

required
left int | Tensor

Left pixel coordinate for cropping (int or int32 tensor of shape (N,) for per-image)

required
height int

Height of the cropped image

required
width int

Width of the cropped image

required

Returns:

Type Description
Tensor

Cropped tensor of shape (N, C, height, width)

Example
img = torch.rand(2, 3, 224, 224, device='cuda')
# Crop all images at same position
cropped = crop(img, top=56, left=56, height=112, width=112)
# Crop each image at different position
tops = torch.tensor([56, 100], device='cuda', dtype=torch.int32)
lefts = torch.tensor([56, 80], device='cuda', dtype=torch.int32)
cropped = crop(img, top=tops, left=lefts, height=112, width=112)
Note

For MVP, this requires valid crop coordinates (no padding). Future versions will support padding for out-of-bounds crops.
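For valid in-bounds coordinates, per-image cropping is equivalent to slicing each image at its own offsets. A CPU sketch (`crop_reference` is a hypothetical name; no padding is handled, matching the MVP constraint above):

```python
import torch

def crop_reference(image, tops, lefts, height, width):
    # Slice each image at its own (top, left), then restack into a batch
    return torch.stack([
        im[:, t:t + height, l:l + width]
        for im, t, l in zip(image, tops.tolist(), lefts.tolist())
    ])

img = torch.rand(2, 3, 32, 32)
tops = torch.tensor([4, 8], dtype=torch.int32)
lefts = torch.tensor([0, 2], dtype=torch.int32)
out = crop_reference(img, tops, lefts, 16, 16)  # (2, 3, 16, 16)
```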


triton_augment.functional.center_crop

center_crop(image: Tensor, output_size: tuple[int, int] | int) -> torch.Tensor

Crop the center of the image to the given size.

Matches torchvision.transforms.v2.functional.center_crop exactly. Reference: torchvision/transforms/v2/functional/_geometry.py line 2545

Parameters:

Name Type Description Default
image Tensor

Input image tensor of shape (N, C, H, W) on CUDA

required
output_size tuple[int, int] | int

Desired output size (height, width) or int for square crop

required

Returns:

Type Description
Tensor

Center-cropped tensor of shape (N, C, output_size[0], output_size[1])

Raises:

Type Description
ValueError

If output_size is larger than image size

Example
img = torch.rand(2, 3, 224, 224, device='cuda')
# Center crop to 112x112
cropped = center_crop(img, (112, 112))
# or for square crop
cropped = center_crop(img, 112)
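Center cropping reduces to computing centered offsets and slicing. A simplified CPU sketch (`center_crop_reference` is a hypothetical name; it assumes the crop fits inside the image, and exact parity with torchvision's rounding for odd size differences is not guaranteed here):

```python
import torch

def center_crop_reference(image, output_size):
    if isinstance(output_size, int):
        output_size = (output_size, output_size)  # int -> square crop
    h, w = image.shape[-2:]
    ch, cw = output_size
    top, left = (h - ch) // 2, (w - cw) // 2      # centered offsets
    return image[..., top:top + ch, left:left + cw]

img = torch.rand(2, 3, 8, 8)
out = center_crop_reference(img, 4)  # offsets (2, 2)
```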

triton_augment.functional.affine

affine(image: Tensor, angle: float | Tensor, translate: list[float] | Tensor, scale: float | Tensor, shear: list[float] | Tensor, interpolation=InterpolationMode.NEAREST, fill: float | Sequence[float] | None = 0.0, center: list[float] | None = None) -> torch.Tensor

Apply affine transformation to the image.

Matches torchvision.transforms.v2.functional.affine API. Reference: torchvision/transforms/v2/functional/_geometry.py

Parameters:

Name Type Description Default
image Tensor

Input image tensor [N, C, H, W]. Must be on CUDA device.

required
angle float | Tensor

Rotation angle in degrees, counter-clockwise. Can be a scalar (applied to all images) or tensor of shape [N] for per-image angles.

required
translate list[float] | Tensor

Translation as [dx, dy] in pixels. Positive dx moves right, positive dy moves down. Can be a list or tensor of shape [N, 2].

required
scale float | Tensor

Scale factor. Values > 1 zoom in, < 1 zoom out. Can be a scalar or tensor of shape [N].

required
shear list[float] | Tensor

Shear angles [shear_x, shear_y] in degrees. Can be a list or tensor of shape [N, 2].

required
interpolation

Interpolation mode for sampling. Either:
  • InterpolationMode.NEAREST (default): Nearest neighbor, faster but may have slight differences vs torchvision at pixel boundaries.
  • InterpolationMode.BILINEAR: Bilinear interpolation, smoother results.

NEAREST
fill float | Sequence[float] | None

Fill value for pixels outside the image boundaries. Default: 0.0

0.0
center list[float] | None

Center of rotation [x, y] in pixel coordinates. Origin is the upper left corner. Default is the center of the image.

None

Returns:

Type Description
Tensor

Transformed image tensor [N, C, H, W]

Note

For nearest neighbor interpolation, there may be minor differences compared to torchvision at exact pixel boundaries (e.g., coordinates landing at x.5). This is due to different rounding conventions in floating-point arithmetic. Bilinear interpolation does not have this limitation.

Example
img = torch.rand(2, 3, 224, 224, device='cuda')

# Rotate 45 degrees, translate, scale
result = affine(img, angle=45, translate=[10, 20], scale=1.2, shear=[0, 0])

# Per-image rotation with bilinear interpolation
angles = torch.tensor([30.0, 60.0], device='cuda')
result = affine(img, angle=angles, translate=[0, 0], scale=1.0, shear=[0, 0],
                interpolation=InterpolationMode.BILINEAR)

triton_augment.functional.rotate

rotate(image: Tensor, angle: float | Tensor, interpolation=InterpolationMode.NEAREST, expand: bool = False, center: list[float] | None = None, fill: float | Sequence[float] | None = 0.0) -> torch.Tensor

Rotate the image by angle.

Matches torchvision.transforms.v2.functional.rotate API.

Parameters:

Name Type Description Default
image Tensor

Input image tensor [N, C, H, W]. Must be on CUDA device.

required
angle float | Tensor

Rotation angle in degrees, counter-clockwise. Can be a scalar (applied to all images) or tensor of shape [N] for per-image angles.

required
interpolation

Interpolation mode for sampling. Either:
  • InterpolationMode.NEAREST (default): Nearest neighbor, faster.
  • InterpolationMode.BILINEAR: Bilinear interpolation, smoother.

NEAREST
expand bool

If True, expands the output to hold the entire rotated image. Currently not supported (raises NotImplementedError).

False
center list[float] | None

Center of rotation [x, y] in pixel coordinates. Origin is the upper left corner. Default is the center of the image.

None
fill float | Sequence[float] | None

Fill value for pixels outside the image boundaries. Default: 0.0

0.0

Returns:

Type Description
Tensor

Rotated image tensor [N, C, H, W]

Note

For nearest neighbor interpolation, there may be minor differences compared to torchvision at exact pixel boundaries. See affine() for details.

Example
img = torch.rand(4, 3, 224, 224, device='cuda')

# Rotate all images by 45 degrees
result = rotate(img, angle=45)

# Per-image rotation with bilinear interpolation
angles = torch.tensor([0, 90, 180, 270], device='cuda', dtype=torch.float32)
result = rotate(img, angle=angles, interpolation=InterpolationMode.BILINEAR)

triton_augment.functional.horizontal_flip

horizontal_flip(image: Tensor, flip_mask: Tensor | None = None) -> torch.Tensor

Flip the image horizontally (left to right), with optional per-image control.

Matches torchvision.transforms.v2.functional.horizontal_flip exactly when flip_mask=None. Reference: torchvision/transforms/v2/functional/_geometry.py line 56

Parameters:

Name Type Description Default
image Tensor

Input image tensor of shape (N, C, H, W) on CUDA

required
flip_mask Tensor | None

Optional uint8 or bool tensor of shape (N,) indicating which images to flip (0=no flip, 1=flip). If None, flips all images (default behavior).

None

Returns:

Type Description
Tensor

Horizontally flipped tensor of the same shape

Example
img = torch.rand(2, 3, 224, 224, device='cuda')
# Flip all images
flipped = horizontal_flip(img)
# Flip only first image
flip_mask = torch.tensor([1, 0], device='cuda', dtype=torch.bool)
flipped = horizontal_flip(img, flip_mask)
Note

This uses a custom Triton kernel. For standalone flip operations, PyTorch's tensor.flip(-1) is highly optimized and may be comparable. The main benefit is when fusing with crop (see fused_crop_flip).
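The per-image mask semantics can be emulated with `tensor.flip(-1)` plus `torch.where`. A CPU sketch (`flip_reference` is a hypothetical name, shown only to make the mask behavior concrete):

```python
import torch

def flip_reference(image, flip_mask=None):
    flipped = image.flip(-1)                     # flip along the width axis
    if flip_mask is None:
        return flipped                           # default: flip every image
    do_flip = flip_mask.view(-1, 1, 1, 1).to(torch.bool)
    return torch.where(do_flip, flipped, image)  # 1/True = flip, 0/False = keep

img = torch.rand(2, 3, 8, 8)
mask = torch.tensor([1, 0], dtype=torch.bool)
out = flip_reference(img, mask)
```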


Utility Functions

triton_augment.enable_autotune

enable_autotune()

Enable kernel auto-tuning for optimal performance.

When enabled, Triton will test multiple kernel configurations and cache the best one for your GPU and image sizes.

Example
import triton_augment as ta
ta.enable_autotune()
# Now kernels will auto-tune on first use

triton_augment.disable_autotune

disable_autotune()

Disable kernel auto-tuning and use fixed defaults.

When disabled, kernels use fixed configurations that work well across most GPUs and image sizes without tuning overhead.

Example
import triton_augment as ta
ta.disable_autotune()
# Now kernels will use fixed defaults (faster, good performance)

triton_augment.is_autotune_enabled

is_autotune_enabled() -> bool

Check if auto-tuning is currently enabled.

Returns:

Name Type Description
bool bool

True if auto-tuning is enabled, False otherwise


triton_augment.warmup_cache

warmup_cache(batch_sizes: Tuple[int, ...] = (32, 64), image_sizes: Tuple[int, ...] = (224, 256, 512), verbose: bool = True)

Pre-populate the auto-tuning cache for common image sizes.

This function runs the fused kernel with various common configurations to trigger auto-tuning and cache the optimal settings. This eliminates the 5-10 second delay on first use.

Parameters:

Name Type Description Default
batch_sizes Tuple[int, ...]

Tuple of batch sizes to warm up (default: (32, 64))

(32, 64)
image_sizes Tuple[int, ...]

Tuple of square image sizes to warm up (default: (224, 256, 512))

(224, 256, 512)
verbose bool

Whether to print progress messages (default: True)

True
Example
import triton_augment as ta
# Warm up cache for common training scenarios
ta.warmup_cache()
# Custom sizes for your specific use case
ta.warmup_cache(batch_sizes=(16, 128), image_sizes=(128, 384))

Input Requirements

Transform Classes

Transform classes (e.g., TritonFusedAugment, TritonColorJitter, etc.) accept:

  • Device: CUDA (GPU) or CPU - CPU tensors are automatically moved to GPU
  • Shape: (C, H, W), (N, C, H, W), or (N, T, C, H, W) - 3D, 4D, or 5D (video)
  • Dtype: float32 or float16
  • Range: [0, 1] for color operations (required)

Notes:

  • 3D tensors (C, H, W) are automatically converted to (1, C, H, W) internally for processing
  • 5D tensors (N, T, C, H, W) are supported for video augmentation (batch, frames, channels, height, width)
  • For 5D inputs, use same_on_frame=True (default) for consistent augmentation across frames, or same_on_frame=False for independent per-frame augmentation
  • After normalization, values can be outside [0, 1] range
  • CPU tensors are automatically transferred to CUDA for GPU processing

Functional API

Functions in the functional API (e.g., fused_augment(), crop(), normalize(), etc.) expect:

  • Device: CUDA (GPU) - must be on CUDA device
  • Shape: (N, C, H, W) - 4D tensors only
  • Dtype: float32 or float16
  • Range: [0, 1] for color operations (required)

Note: Transform classes handle 3D/5D normalization internally. If using the functional API directly, ensure inputs are already in 4D format (N, C, H, W).
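Batching a single image for the functional API is a one-line `unsqueeze`. A sketch (the commented-out call is illustrative; substitute any functional-API call and move the tensor to CUDA first):

```python
import torch

# The functional API expects 4-D (N, C, H, W) input;
# a single (C, H, W) image must be batched first.
img3d = torch.rand(3, 224, 224)
batched = img3d.unsqueeze(0)   # (1, 3, 224, 224)
# e.g. out = normalize(batched.cuda(),
#                      mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
single = batched.squeeze(0)    # back to (C, H, W) after processing
```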


Performance Tips

1. Use Fused Kernel Even for Partial Operations

Key insight: Even if you only need a subset of operations, use TritonFusedAugment or F.fused_augment for best performance! Simply set unused operations to no-op values:

# Example: Only need crop + normalize (no flip, no color jitter)
transform = ta.TritonFusedAugment(
    crop_size=224,
    horizontal_flip_p=0.0,      # No flip
    brightness=0.0,             # No brightness
    contrast=0.0,               # No contrast
    saturation=0.0,             # No saturation
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225)
)
# Still faster than calling crop() + normalize() separately!

The fused kernel is optimized to skip operations set to no-op values at compile time.

2. Auto-Tuning

Enable auto-tuning for optimal performance on your specific GPU and data sizes:

import triton_augment as ta

ta.enable_autotune()  # Enable once at start of training

# Optional: Pre-compile kernels for your data sizes
ta.warmup_cache(batch_sizes=(32, 64), image_sizes=(224, 512))

See Auto-Tuning Guide for detailed configuration.


Additional Resources