API Reference¶
Complete auto-generated API reference for all Triton-Augment operations.
Transform Classes¶
Transform classes provide stateful, random augmentations similar to torchvision. Recommended for training pipelines.
triton_augment.TritonFusedAugment ¶
Bases: Module
Fused augmentation: All operations in ONE kernel.
This transform combines ALL augmentations in a single GPU kernel launch. Unified Fusion: Combines Affine (Rotation/Translation/Scale/Shear) + Crop + Flip + Color Jitter + Grayscale + Normalize in order into a single kernel launch.
Performance: up to 14x faster than torchvision.transforms.Compose!
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `crop_size` | `int \| tuple[int, int] \| None` | Desired output size (int or tuple). If int, output is square `(crop_size, crop_size)`. If None, output size equals input size (no cropping). | `None` |
| `horizontal_flip_p` | `float` | Probability of horizontal flip (0.0 = no flip). | `0.0` |
| `degrees` | `float \| tuple[float, float]` | Rotation range in degrees. If float, the range is `(-degrees, +degrees)`; if tuple, `(degrees[0], degrees[1])`. 0 disables rotation. | `0` |
| `translate` | `tuple[float, float] \| None` | Translation range as a fraction of image size `(tx, ty)`. E.g. `(0.1, 0.1)` translates up to 10% of width/height. None disables translation. | `None` |
| `scale` | `tuple[float, float] \| None` | Scale range `(min, max)`. E.g. `(0.8, 1.2)` scales between 80% and 120%. None disables scaling. | `None` |
| `shear` | `float \| tuple[float, float] \| None` | Shear range in degrees. If float, the range is `(-shear, +shear)` for the x-axis. A 2-tuple gives `(shear[0], shear[1])` for the x-axis; a 4-tuple adds `(shear[2], shear[3])` for the y-axis. None disables shearing. | `None` |
| `interpolation` | `InterpolationMode` | Interpolation mode for affine (`InterpolationMode.NEAREST` or `BILINEAR`). Only used in affine mode. | `NEAREST` |
| `fill` | `float` | Fill value for out-of-bounds pixels in affine mode. | `0.0` |
| `brightness` | `float \| tuple[float, float]` | How much to jitter brightness. If float, the factor is chosen uniformly from `[max(0, 1 - brightness), 1 + brightness]`; if tuple, from `[brightness[0], brightness[1]]`. | `0` |
| `contrast` | `float \| tuple[float, float]` | How much to jitter contrast (same format as `brightness`). | `0` |
| `saturation` | `float \| tuple[float, float]` | How much to jitter saturation (same format as `brightness`). | `0` |
| `grayscale_p` | `float` | Probability of converting to grayscale (0.0 = never). | `0.0` |
| `mean` | `Optional[tuple[float, float, float]]` | Sequence of means for the R, G, B channels. If None, normalization is skipped. | `None` |
| `std` | `Optional[tuple[float, float, float]]` | Sequence of standard deviations for the R, G, B channels. If None, normalization is skipped. | `None` |
| `same_on_batch` | `bool` | If True, all images in the batch (N dimension) share the same random parameters; if False, each image gets its own. | `False` |
| `same_on_frame` | `bool` | If True, all frames in a video (T dimension) share the same random parameters; if False, each frame gets its own. Only applies to 5D input `[N, T, C, H, W]`. | `True` |
Example

    # Simple mode (crop + flip + color)
    transform = ta.TritonFusedAugment(
        crop_size=112,
        horizontal_flip_p=0.5,
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )

    # Affine mode (rotation + scale + color)
    transform = ta.TritonFusedAugment(
        crop_size=224,
        degrees=30,  # enables affine mode
        scale=(0.8, 1.2),
        horizontal_flip_p=0.5,
        brightness=0.2,
        interpolation=InterpolationMode.BILINEAR,
    )

    img = torch.rand(4, 3, 224, 224, device='cuda')
    result = transform(img)  # single kernel launch!
Note
- Uses FAST contrast (centered scaling), not torchvision's blend-with-mean
- Input must be (C, H, W), (N, C, H, W), or (N, T, C, H, W) float tensor on CUDA in [0, 1] range
Functions¶
forward ¶
Apply all augmentations in a single fused kernel.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `image` | `Tensor` | Input tensor of shape (C, H, W), (N, C, H, W), or (N, T, C, H, W). Can be on CPU or CUDA (moved to CUDA automatically). | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Augmented tensor of the same shape and device as the input. |
triton_augment.TritonColorJitterNormalize ¶
Bases: Module
Combined color jitter, random grayscale, and normalization in a single fused operation.
This class combines TritonColorJitter, TritonRandomGrayscale, and TritonNormalize into a single operation that uses a fused kernel for maximum performance. This is the recommended way to apply color augmentations and normalization.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `brightness` | `Optional[Union[float, Sequence[float]]]` | How much to jitter brightness (same as TritonColorJitter). | `None` |
| `contrast` | `Optional[Union[float, Sequence[float]]]` | How much to jitter contrast (same as TritonColorJitter). | `None` |
| `saturation` | `Optional[Union[float, Sequence[float]]]` | How much to jitter saturation (same as TritonColorJitter). | `None` |
| `grayscale_p` | `float` | Probability of converting to grayscale. | `0.0` |
| `mean` | `Optional[Tuple[float, float, float]]` | Sequence of means for normalization (R, G, B). If None, normalization is skipped. | `None` |
| `std` | `Optional[Tuple[float, float, float]]` | Sequence of standard deviations for normalization (R, G, B). If None, normalization is skipped. | `None` |
| `same_on_batch` | `bool` | If True, all images in the batch share the same random parameters; if False, each image gets its own. | `False` |
Example

    # Full augmentation pipeline in one transform (per-image randomness)
    transform = TritonColorJitterNormalize(
        brightness=0.2,              # range: [0.8, 1.2]
        contrast=0.2,                # range: [0.8, 1.2]
        saturation=0.2,              # range: [0.8, 1.2]
        grayscale_p=0.1,             # 10% chance of grayscale (per-image)
        mean=(0.485, 0.456, 0.406),  # ImageNet normalization (optional)
        std=(0.229, 0.224, 0.225),   # ImageNet normalization (optional)
        same_on_batch=False,
    )
    img = torch.rand(4, 3, 224, 224, device='cuda')
    augmented = transform(img)  # each image gets a different augmentation

    # Without normalization (mean=None, std=None by default)
    transform_no_norm = TritonColorJitterNormalize(
        brightness=0.2, contrast=0.2, saturation=0.2,
    )
Functions¶
forward ¶
Apply random color jitter and normalization in a single fused operation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `img` | `Tensor` | Input tensor of shape (C, H, W), (N, C, H, W), or (N, T, C, H, W). Can be on CPU or CUDA (moved to CUDA automatically). | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Augmented and normalized tensor of the same shape and device as the input. |
triton_augment.TritonRandomCropFlip ¶
Bases: Module
Fused random crop + random horizontal flip.
This class combines random crop and random horizontal flip in a SINGLE kernel launch, eliminating intermediate memory transfers.
Performance: ~1.5-2x faster than applying TritonRandomCrop + TritonRandomHorizontalFlip sequentially.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `size` | `Union[int, Sequence[int]]` | Desired output size (height, width), or an int for a square crop. | required |
| `horizontal_flip_p` | `float` | Probability of horizontal flip. | `0.5` |
| `same_on_batch` | `bool` | If True, all images in the batch (N dimension) share the same random parameters; if False, each image gets its own. | `False` |
| `same_on_frame` | `bool` | If True, all frames in a video (T dimension) share the same random parameters; if False, each frame gets its own. Only applies to 5D input `[N, T, C, H, W]`. | `True` |
Example

    # Fused version (FAST: single kernel, per-image randomness)
    transform_fused = TritonRandomCropFlip(112, horizontal_flip_p=0.5, same_on_batch=False)
    img = torch.rand(4, 3, 224, 224, device='cuda')
    result = transform_fused(img)  # each image gets a different crop & flip

    # Equivalent sequential version (SLOWER: 2 kernels)
    transform_seq = nn.Sequential(
        TritonRandomCrop(112, same_on_batch=False),
        TritonRandomHorizontalFlip(p=0.5, same_on_batch=False),
    )
    result_seq = transform_seq(img)
Note
The fused version uses compile-time branching (tl.constexpr), so there's zero overhead when flip is not triggered.
triton_augment.TritonColorJitter ¶
Bases: Module
Randomly change the brightness, contrast, and saturation of an image.
This is a GPU-accelerated version of torchvision.transforms.v2.ColorJitter that uses a fused kernel for maximum performance.
IMPORTANT: Contrast uses FAST mode (centered scaling: `(pixel - 0.5) * factor + 0.5`), NOT torchvision's blend-with-mean approach. This is much faster and gives similar visual results.

If you need exact torchvision behavior, use the individual functional APIs:

- `F.adjust_brightness()` (exact)
- `F.adjust_contrast()` (torchvision-exact, slower)
- `F.adjust_saturation()` (exact)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `brightness` | `Optional[Union[float, Sequence[float]]]` | How much to jitter brightness. brightness_factor is chosen uniformly from `[max(0, 1 - brightness), 1 + brightness]` or the given `[min, max]`. Values should be non-negative. | `None` |
| `contrast` | `Optional[Union[float, Sequence[float]]]` | How much to jitter contrast (uses FAST mode). contrast_factor is chosen uniformly from `[max(0, 1 - contrast), 1 + contrast]` or the given `[min, max]`. Values should be non-negative. | `None` |
| `saturation` | `Optional[Union[float, Sequence[float]]]` | How much to jitter saturation. saturation_factor is chosen uniformly from `[max(0, 1 - saturation), 1 + saturation]` or the given `[min, max]`. Values should be non-negative. | `None` |
| `same_on_batch` | `bool` | If True, all images in the batch (N dimension) share the same random parameters; if False, each image gets its own. | `False` |
| `same_on_frame` | `bool` | If True, all frames in a video (T dimension) share the same random parameters; if False, each frame gets its own. Only applies to 5D input `[N, T, C, H, W]`. | `True` |
Example

    # Basic usage with per-image randomness
    transform = TritonColorJitter(
        brightness=0.2,   # range: [0.8, 1.2]
        contrast=0.2,     # range: [0.8, 1.2] (FAST contrast)
        saturation=0.2,   # range: [0.8, 1.2]
        same_on_batch=False,
    )
    img = torch.rand(4, 3, 224, 224, device='cuda')
    augmented = transform(img)  # each image gets a different augmentation

    # Custom ranges
    transform = TritonColorJitter(
        brightness=(0.5, 1.5),  # custom range
        contrast=(0.7, 1.3),    # custom range (FAST mode)
        saturation=(0.0, 2.0),  # custom range
    )
Performance
- Uses fused kernel for all operations in a single pass
- Faster than sequential operations
- For even more speed, combine with normalization using TritonColorJitterNormalize
triton_augment.TritonNormalize ¶
Bases: Module
Normalize a tensor image with mean and standard deviation.
This is a GPU-accelerated version of torchvision.transforms.Normalize that uses a Triton kernel for improved performance.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `mean` | `Tuple[float, float, float]` | Sequence of means for each channel (R, G, B). | required |
| `std` | `Tuple[float, float, float]` | Sequence of standard deviations for each channel (R, G, B). | required |
Example
Functions¶
forward ¶
Normalize the input image tensor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `img` | `Tensor` | Input tensor of shape (C, H, W), (N, C, H, W), or (N, T, C, H, W). Can be on CPU or CUDA (moved to CUDA automatically). | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Normalized tensor of the same shape and device as the input. |
triton_augment.TritonRandomGrayscale ¶
Bases: Module
Randomly convert image to grayscale with probability p.
Matches torchvision.transforms.v2.RandomGrayscale behavior with optional per-image randomness.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `p` | `float` | Probability of converting to grayscale. | `0.1` |
| `num_output_channels` | `int` | Number of output channels (1 or 3). Usually 3, to stay compatible with RGB pipelines. | `3` |
| `same_on_batch` | `bool` | If True, all images in the batch (N dimension) make the same grayscale decision; if False, each image decides independently. | `False` |
| `same_on_frame` | `bool` | If True, all frames in a video (T dimension) make the same grayscale decision; if False, each frame decides independently. Only applies to 5D input `[N, T, C, H, W]`. | `True` |
Example

    # Per-image randomness (each image independently converted)
    transform = TritonRandomGrayscale(p=0.5, num_output_channels=3, same_on_batch=False)
    img = torch.rand(4, 3, 224, 224, device='cuda')
    result = transform(img)  # each image has a 50% chance of being grayscale

    # Batch-wide (all images converted or none)
    transform = TritonRandomGrayscale(p=0.5, num_output_channels=3, same_on_batch=True)
    result = transform(img)  # either all 4 images are grayscale or none are
triton_augment.TritonGrayscale ¶
Bases: Module
Convert image to grayscale.
Matches torchvision.transforms.v2.Grayscale behavior. Uses weights: `0.2989*R + 0.587*G + 0.114*B`
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `num_output_channels` | `int` | Number of output channels (1 or 3). If 1, output is single-channel grayscale; if 3, the grayscale result is replicated to 3 channels. | `1` |
Example
triton_augment.TritonRandomCrop ¶
Bases: Module
Crop a random portion of the image.
Matches torchvision.transforms.v2.RandomCrop behavior (simplified MVP version).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `size` | `Union[int, Sequence[int]]` | Desired output size (height, width), or an int for a square crop. | required |
| `same_on_batch` | `bool` | If True, all images in the batch (N dimension) crop at the same position; if False, each image gets its own random position. | `False` |
| `same_on_frame` | `bool` | If True, all frames in a video (T dimension) crop at the same position; if False, each frame gets its own. Only applies to 5D input `[N, T, C, H, W]`. | `True` |
Example
Note

For MVP, padding is not supported; the image must be larger than the crop size. Future versions will support padding, pad_if_needed, fill, and padding_mode.
triton_augment.TritonCenterCrop ¶
Bases: Module
Crop the center of the image.
Matches torchvision.transforms.v2.CenterCrop behavior.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `size` | `Union[int, Sequence[int]]` | Desired output size (height, width), or an int for a square crop. | required |
Example
triton_augment.TritonRandomAffine ¶
Bases: Module
Random affine transformation of the image keeping center invariant.
GPU-accelerated implementation using Triton kernels. Matches the API of torchvision.transforms.v2.RandomAffine.
Supports
- 3D input: (C, H, W) - single image
- 4D input: (N, C, H, W) - batch of images
- 5D input: (N, T, C, H, W) - batch of videos
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `degrees` | `Union[float, Sequence[float]]` | Range of degrees to select from. If a number instead of a sequence like (min, max), the range is `(-degrees, +degrees)`. Set to 0 to deactivate rotations. | required |
| `translate` | `Optional[Tuple[float, float]]` | Tuple of maximum absolute fractions for horizontal and vertical translation. For `translate=(a, b)`, the horizontal shift is sampled from `-img_width * a < dx < img_width * a` and the vertical shift from `-img_height * b < dy < img_height * b`. Will not translate by default. | `None` |
| `scale` | `Optional[Tuple[float, float]]` | Scaling factor interval, e.g. `(a, b)`: scale is sampled from `a <= scale <= b`. Keeps the original scale by default. | `None` |
| `shear` | `Optional[Union[float, Sequence[float]]]` | Range of degrees to select from. If a number, an x-axis shear in `(-shear, +shear)` is applied. A 2-tuple gives an x-axis shear in `(shear[0], shear[1])`; a 4-tuple gives an x-axis shear in `(shear[0], shear[1])` and a y-axis shear in `(shear[2], shear[3])`. | `None` |
| `interpolation` | `InterpolationMode` | Interpolation mode for sampling: `InterpolationMode.NEAREST` (default; nearest neighbor, faster) or `InterpolationMode.BILINEAR` (smoother). | `NEAREST` |
| `fill` | `float` | Constant fill value for areas outside the transformed image. | `0.0` |
| `center` | `Optional[Tuple[int, int]]` | Optional center of rotation (x, y) in pixel coordinates, with the origin at the upper-left corner. Defaults to the center of the image. | `None` |
| `same_on_batch` | `bool` | If True, all images in the batch share the same random parameters. | `False` |
| `same_on_frame` | `bool` | If True, all frames in a video (5D input) share the same random parameters. | `True` |
Note
For nearest neighbor interpolation, there may be minor differences compared to torchvision at exact pixel boundaries due to floating-point rounding. Bilinear interpolation does not have this limitation.
Example

    transform = TritonRandomAffine(
        degrees=15,
        translate=(0.1, 0.1),
        scale=(0.8, 1.2),
        shear=10,
        interpolation=InterpolationMode.BILINEAR,
    )

    # Apply to a batch of images
    img = torch.rand(4, 3, 224, 224, device='cuda')
    result = transform(img)

    # Apply to a video with the same transform per frame
    video = torch.rand(2, 8, 3, 112, 112, device='cuda')
    result = transform(video)
triton_augment.TritonRandomRotation ¶
Bases: TritonRandomAffine
Random rotation of the image.
GPU-accelerated implementation using Triton kernels. Matches the API of torchvision.transforms.v2.RandomRotation.
Supports
- 3D input: (C, H, W) - single image
- 4D input: (N, C, H, W) - batch of images
- 5D input: (N, T, C, H, W) - batch of videos
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `degrees` | `Union[float, Sequence[float]]` | Range of degrees to select from. If a number instead of a sequence like (min, max), the range is `(-degrees, +degrees)`. | required |
| `interpolation` | `InterpolationMode` | Interpolation mode for sampling: `InterpolationMode.NEAREST` (default; nearest neighbor, faster) or `InterpolationMode.BILINEAR` (smoother). | `NEAREST` |
| `expand` | `bool` | If True, expands the output to hold the entire rotated image. Currently not supported (raises NotImplementedError). | `False` |
| `center` | `Optional[Tuple[int, int]]` | Optional center of rotation (x, y) in pixel coordinates, with the origin at the upper-left corner. Defaults to the center of the image. | `None` |
| `fill` | `float` | Constant fill value for areas outside the rotated image. | `0.0` |
| `same_on_batch` | `bool` | If True, all images in the batch share the same random rotation angle. | `False` |
| `same_on_frame` | `bool` | If True, all frames in a video (5D input) share the same random rotation angle. | `True` |
Note
For nearest neighbor interpolation, there may be minor differences compared to torchvision at exact pixel boundaries due to floating-point rounding. Bilinear interpolation does not have this limitation.
Example
triton_augment.TritonRandomHorizontalFlip ¶
Bases: Module
Horizontally flip the image randomly with probability p.
Matches torchvision.transforms.v2.RandomHorizontalFlip behavior.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `p` | `float` | Probability of flipping. | `0.5` |
| `same_on_batch` | `bool` | If True, all images in the batch (N dimension) share the same flip decision; if False, each image decides independently. | `False` |
| `same_on_frame` | `bool` | If True, all frames in a video (T dimension) share the same flip decision; if False, each frame decides independently. Only applies to 5D input `[N, T, C, H, W]`. | `True` |
Example
Functional API¶
Low-level functional interface for fine-grained control with fixed parameters. Use when you need deterministic operations.
triton_augment.functional.fused_augment ¶
    fused_augment(
        image: Tensor,
        top: int | Tensor,
        left: int | Tensor,
        height: int,
        width: int,
        flip_horizontal: bool | Tensor = False,
        angle: float | Tensor = 0.0,
        translate: list[float] | Tensor = [0.0, 0.0],
        scale: float | Tensor = 1.0,
        shear: list[float] | Tensor = [0.0, 0.0],
        interpolation: str = 'nearest',
        fill: float = 0.0,
        center: list[float] | None = None,
        brightness_factor: float | Tensor = 1.0,
        contrast_factor: float | Tensor = 1.0,
        saturation_factor: float | Tensor = 1.0,
        grayscale: bool | Tensor = False,
        mean: tuple[float, float, float] | None = None,
        std: tuple[float, float, float] | None = None,
    ) -> torch.Tensor
Fused augmentation: ALL operations in ONE kernel.
Combines geometric (crop + flip + affine) and pixel (color + normalize) operations in a single GPU kernel, providing maximum performance.
Automatically selects the most efficient kernel mode:

1. Simple mode (crop + flip only): used when all affine parameters are identity. Fastest path; uses integer indexing.
2. Affine mode (general): used when any affine parameter is non-identity. Uses matrix multiplication and interpolation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `image` | `Tensor` | Input image tensor of shape (N, C, H, W) on CUDA. | required |
| `top` | `int \| Tensor` | Crop top offset (int, or int32 tensor of shape (N,) for per-image offsets). | required |
| `left` | `int \| Tensor` | Crop left offset (int, or int32 tensor of shape (N,) for per-image offsets). | required |
| `height` | `int` | Crop height. | required |
| `width` | `int` | Crop width. | required |
| `flip_horizontal` | `bool \| Tensor` | Whether to flip horizontally (bool, or uint8 tensor of shape (N,) for per-image decisions). | `False` |
| `angle` | `float \| Tensor` | Rotation angle in degrees. | `0.0` |
| `translate` | `list[float] \| Tensor` | Translation [dx, dy]. | `[0.0, 0.0]` |
| `scale` | `float \| Tensor` | Scale factor. | `1.0` |
| `shear` | `list[float] \| Tensor` | Shear angles [sx, sy]. | `[0.0, 0.0]` |
| `interpolation` | `str` | `"nearest"` or `"bilinear"`. | `'nearest'` |
| `fill` | `float` | Fill value for out-of-bounds pixels. | `0.0` |
| `center` | `list[float] \| None` | Center of rotation (default: image center). | `None` |
| `brightness_factor` | `float \| Tensor` | Brightness multiplier (float, or tensor of shape (N,) for per-image; 1.0 = no change). | `1.0` |
| `contrast_factor` | `float \| Tensor` | Contrast multiplier (float, or tensor of shape (N,) for per-image; 1.0 = no change). Uses FAST mode. | `1.0` |
| `saturation_factor` | `float \| Tensor` | Saturation multiplier (float, or tensor of shape (N,) for per-image; 1.0 = no change). | `1.0` |
| `grayscale` | `bool \| Tensor` | Whether to convert to grayscale (bool, or uint8 tensor of shape (N,) for per-image decisions). | `False` |
| `mean` | `tuple[float, float, float] \| None` | Normalization means (None = skip normalization). | `None` |
| `std` | `tuple[float, float, float] \| None` | Normalization stds (None = skip normalization). | `None` |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Transformed tensor of shape (N, C, height, width). |
Example

    img = torch.rand(4, 3, 224, 224, device='cuda')

    # Simple mode (crop + flip + color)
    result = fused_augment(
        img, top=0, left=0, height=224, width=224,
        flip_horizontal=True,
        brightness_factor=1.2,
    )

    # Affine mode (rotation + crop + color)
    result = fused_augment(
        img, top=0, left=0, height=224, width=224,
        angle=30.0,  # triggers affine mode
        brightness_factor=1.2,
    )
triton_augment.functional.adjust_brightness ¶
Adjust brightness of an image (MULTIPLICATIVE operation).
Matches torchvision.transforms.v2.functional.adjust_brightness exactly. Reference: torchvision/transforms/v2/functional/_color.py line 114-125
Formula: output = input * brightness_factor
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `image` | `Tensor` | Input image tensor of shape (N, C, H, W) on CUDA. | required |
| `brightness_factor` | `float` | How much to adjust the brightness. Must be non-negative: 0 gives a black image, 1 gives the original image, 2 doubles the brightness. | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Brightness-adjusted tensor of the same shape and dtype. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If `brightness_factor` is negative. |
triton_augment.functional.adjust_contrast ¶
Adjust contrast of an image.
Matches torchvision.transforms.v2.functional.adjust_contrast exactly. Reference: torchvision/transforms/v2/functional/_color.py line 190-206
    output = blend(image, grayscale_mean, contrast_factor)
           = image * contrast_factor + grayscale_mean * (1 - contrast_factor)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `image` | `Tensor` | Input image tensor of shape (N, C, H, W) on CUDA. | required |
| `contrast_factor` | `float` | How much to adjust the contrast. Must be non-negative: 0 gives a gray image, 1 gives the original image, values > 1 increase contrast. | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Contrast-adjusted tensor of the same shape and dtype. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If `contrast_factor` is negative. |
triton_augment.functional.adjust_contrast_fast ¶
Adjust contrast of an image using FAST centered scaling.
This is faster than adjust_contrast() because it doesn't require computing the grayscale mean. Uses formula: output = (input - 0.5) * contrast_factor + 0.5
NOTE: This is NOT equivalent to torchvision's adjust_contrast, but provides similar perceptual results and is fully fusible with other operations.
Use this when:

- You want maximum performance with fusion
- Exact torchvision reproduction is not critical

Use `adjust_contrast()` when:

- You need exact torchvision compatibility
- Reproducibility with torchvision is required
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `image` | `Tensor` | Input image tensor of shape (N, C, H, W) on CUDA. | required |
| `contrast_factor` | `float` | How much to adjust the contrast. Must be non-negative: 0.5 decreases contrast, 1.0 gives the original image, values > 1.0 increase contrast. | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Contrast-adjusted tensor of the same shape and dtype. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If `contrast_factor` is negative. |
triton_augment.functional.adjust_saturation ¶
Adjust color saturation of an image.
Matches torchvision.transforms.v2.functional.adjust_saturation exactly.
Formula: `output = image * saturation_factor + grayscale(image) * (1 - saturation_factor)` (the same blend as `adjust_contrast`, but toward the per-pixel grayscale image rather than its mean).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `image` | `Tensor` | Input image tensor of shape (N, C, H, W) on CUDA. | required |
| `saturation_factor` | `float` | How much to adjust the saturation. Must be non-negative: 0 gives a grayscale image, 1 gives the original image, values > 1 increase saturation. | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Saturation-adjusted tensor of the same shape and dtype. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If `saturation_factor` is negative. |
triton_augment.functional.normalize ¶
    normalize(image: Tensor, mean: tuple[float, float, float], std: tuple[float, float, float]) -> torch.Tensor
Normalize a tensor image with mean and standard deviation.
This function normalizes each channel:

    output[c] = (input[c] - mean[c]) / std[c]
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `image` | `Tensor` | Input image tensor of shape (N, C, H, W) on CUDA. | required |
| `mean` | `tuple[float, float, float]` | Tuple of mean values for each channel (R, G, B). | required |
| `std` | `tuple[float, float, float]` | Tuple of standard deviation values for each channel (R, G, B). | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Normalized tensor of the same shape and dtype. |
triton_augment.functional.rgb_to_grayscale ¶
    rgb_to_grayscale(image: Tensor, num_output_channels: int = 1, grayscale_mask: Tensor | None = None) -> torch.Tensor
Convert RGB image to grayscale with optional per-image masking.
Matches torchvision.transforms.v2.functional.rgb_to_grayscale exactly. Uses weights: `0.2989*R + 0.587*G + 0.114*B`
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `image` | `Tensor` | Input image tensor of shape (N, 3, H, W) on CUDA. | required |
| `num_output_channels` | `int` | Number of output channels (1 or 3). If 3, the grayscale result is replicated across channels. | `1` |
| `grayscale_mask` | `Tensor \| None` | Optional per-image mask [N] (uint8: 0 = keep original, 1 = convert to gray). If None, converts all images. | `None` |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Grayscale tensor of shape (N, num_output_channels, H, W). |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If `num_output_channels` is not in {1, 3}, or if the input is not RGB. |
Example

    img = torch.rand(4, 3, 224, 224, device='cuda')

    # Convert all images
    gray = rgb_to_grayscale(img, num_output_channels=3)

    # Convert only some images (per-image mask)
    mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool, device='cuda')
    gray = rgb_to_grayscale(img, num_output_channels=3, grayscale_mask=mask)
triton_augment.functional.crop ¶
Crop a rectangular region from the input image, with optional per-image crop positions.
Matches torchvision.transforms.v2.functional.crop exactly when using scalar top/left. Reference: torchvision/transforms/v2/functional/_geometry.py line 1787
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `image` | `Tensor` | Input image tensor of shape (N, C, H, W) on CUDA. | required |
| `top` | `int \| Tensor` | Top pixel coordinate for cropping (int, or int32 tensor of shape (N,) for per-image offsets). | required |
| `left` | `int \| Tensor` | Left pixel coordinate for cropping (int, or int32 tensor of shape (N,) for per-image offsets). | required |
| `height` | `int` | Height of the cropped image. | required |
| `width` | `int` | Width of the cropped image. | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Cropped tensor of shape (N, C, height, width). |
Example

    img = torch.rand(2, 3, 224, 224, device='cuda')

    # Crop all images at the same position
    cropped = crop(img, top=56, left=56, height=112, width=112)

    # Crop each image at a different position
    tops = torch.tensor([56, 100], device='cuda', dtype=torch.int32)
    lefts = torch.tensor([56, 80], device='cuda', dtype=torch.int32)
    cropped = crop(img, top=tops, left=lefts, height=112, width=112)
Note
For MVP, this requires valid crop coordinates (no padding). Future versions will support padding for out-of-bounds crops.
triton_augment.functional.center_crop ¶
Crop the center of the image to the given size.
Matches torchvision.transforms.v2.functional.center_crop exactly. Reference: torchvision/transforms/v2/functional/_geometry.py line 2545
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `image` | `Tensor` | Input image tensor of shape (N, C, H, W) on CUDA. | required |
| `output_size` | `tuple[int, int] \| int` | Desired output size (height, width), or an int for a square crop. | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Center-cropped tensor of shape (N, C, output_size[0], output_size[1]). |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If `output_size` is larger than the image size. |
triton_augment.functional.affine ¶
    affine(
        image: Tensor,
        angle: float | Tensor,
        translate: list[float] | Tensor,
        scale: float | Tensor,
        shear: list[float] | Tensor,
        interpolation=InterpolationMode.NEAREST,
        fill: float | Sequence[float] | None = 0.0,
        center: list[float] | None = None,
    ) -> torch.Tensor
Apply affine transformation to the image.
Matches torchvision.transforms.v2.functional.affine API. Reference: torchvision/transforms/v2/functional/_geometry.py
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| image | Tensor | Input image tensor [N, C, H, W]. Must be on a CUDA device. | required |
| angle | float \| Tensor | Rotation angle in degrees, counter-clockwise. Can be a scalar (applied to all images) or a tensor of shape [N] for per-image angles. | required |
| translate | list[float] \| Tensor | Translation as [dx, dy] in pixels. Positive dx moves right, positive dy moves down. Can be a list or a tensor of shape [N, 2]. | required |
| scale | float \| Tensor | Scale factor. Values > 1 zoom in, < 1 zoom out. Can be a scalar or a tensor of shape [N]. | required |
| shear | list[float] \| Tensor | Shear angles [shear_x, shear_y] in degrees. Can be a list or a tensor of shape [N, 2]. | required |
| interpolation | | Interpolation mode for sampling: InterpolationMode.NEAREST (default, faster but may differ slightly from torchvision at pixel boundaries) or InterpolationMode.BILINEAR (smoother results). | NEAREST |
| fill | float \| Sequence[float] \| None | Fill value for pixels outside the image boundaries. | 0.0 |
| center | list[float] \| None | Center of rotation [x, y] in pixel coordinates; the origin is the upper left corner. Defaults to the center of the image. | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Transformed image tensor [N, C, H, W] |
Note
For nearest neighbor interpolation, there may be minor differences compared to torchvision at exact pixel boundaries (e.g., coordinates landing at x.5). This is due to different rounding conventions in floating-point arithmetic. Bilinear interpolation does not have this limitation.
Example
```python
img = torch.rand(2, 3, 224, 224, device='cuda')
# Rotate 45 degrees, translate, scale
result = affine(img, angle=45, translate=[10, 20], scale=1.2, shear=[0, 0])
# Per-image rotation with bilinear interpolation
angles = torch.tensor([30.0, 60.0], device='cuda')
result = affine(img, angle=angles, translate=[0, 0], scale=1.0, shear=[0, 0],
                interpolation=InterpolationMode.BILINEAR)
```
triton_augment.functional.rotate ¶
rotate(image: Tensor, angle: float | Tensor, interpolation=InterpolationMode.NEAREST, expand: bool = False, center: list[float] | None = None, fill: float | Sequence[float] | None = 0.0) -> torch.Tensor
Rotate the image by angle.
Matches torchvision.transforms.v2.functional.rotate API.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| image | Tensor | Input image tensor [N, C, H, W]. Must be on a CUDA device. | required |
| angle | float \| Tensor | Rotation angle in degrees, counter-clockwise (matching the torchvision convention and the affine() parameter above). Can be a scalar (applied to all images) or a tensor of shape [N] for per-image angles. | required |
| interpolation | | Interpolation mode for sampling: InterpolationMode.NEAREST (default, faster) or InterpolationMode.BILINEAR (smoother). | NEAREST |
| expand | bool | If True, expands the output to hold the entire rotated image. Currently not supported (raises NotImplementedError). | False |
| center | list[float] \| None | Center of rotation [x, y] in pixel coordinates; the origin is the upper left corner. Defaults to the center of the image. | None |
| fill | float \| Sequence[float] \| None | Fill value for pixels outside the image boundaries. | 0.0 |
Returns:
| Type | Description |
|---|---|
| Tensor | Rotated image tensor [N, C, H, W] |
Note
For nearest neighbor interpolation, there may be minor differences compared to torchvision at exact pixel boundaries. See affine() for details.
Example
```python
img = torch.rand(4, 3, 224, 224, device='cuda')
# Rotate all images by 45 degrees
result = rotate(img, angle=45)
# Per-image rotation with bilinear interpolation
angles = torch.tensor([0, 90, 180, 270], device='cuda', dtype=torch.float32)
result = rotate(img, angle=angles, interpolation=InterpolationMode.BILINEAR)
```
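For readers who want the pixel-mapping semantics spelled out, here is a minimal NumPy reference for nearest-neighbor rotation about the image center via inverse mapping (a sketch following the counter-clockwise torchvision convention; the Triton kernel's rounding at exact pixel boundaries may differ, per the note above):

```python
import math
import numpy as np

def rotate_nn(img, angle, fill=0.0):
    """Rotate an (H, W) array `angle` degrees counter-clockwise about
    the image center with nearest-neighbor sampling: each output pixel
    is mapped back to its source coordinate (inverse mapping)."""
    h, w = img.shape
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    a = math.radians(angle)
    out = np.full_like(img, fill)
    for y in range(h):
        for x in range(w):
            # Source coordinate for output pixel (x, y)
            xs = math.cos(a) * (x - cx) - math.sin(a) * (y - cy) + cx
            ys = math.sin(a) * (x - cx) + math.cos(a) * (y - cy) + cy
            xi, yi = int(round(xs)), int(round(ys))
            if 0 <= xi < w and 0 <= yi < h:
                out[y, x] = img[yi, xi]
    return out

img = np.arange(16, dtype=float).reshape(4, 4)
rot90 = rotate_nn(img, 90)  # matches np.rot90(img)
```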
triton_augment.functional.horizontal_flip ¶
Flip the image horizontally (left to right), with optional per-image control.
Matches torchvision.transforms.v2.functional.horizontal_flip exactly when flip_mask=None. Reference: torchvision/transforms/v2/functional/_geometry.py line 56
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| image | Tensor | Input image tensor of shape (N, C, H, W) on CUDA | required |
| flip_mask | Tensor \| None | Optional uint8 tensor of shape (N,) indicating which images to flip (0 = no flip, 1 = flip). If None, flips all images (default behavior). | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Horizontally flipped tensor of the same shape |
Example
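A NumPy sketch of the documented flip_mask semantics (illustrative only; `horizontal_flip_ref` is not part of the library):

```python
import numpy as np

def horizontal_flip_ref(img, flip_mask=None):
    """Flip (N, C, H, W) arrays along the width axis; with flip_mask,
    only images whose mask entry is 1 are flipped."""
    flipped = img[..., ::-1]
    if flip_mask is None:
        return flipped
    mask = np.asarray(flip_mask).astype(bool).reshape(-1, 1, 1, 1)
    return np.where(mask, flipped, img)

img = np.arange(2 * 1 * 2 * 3).reshape(2, 1, 2, 3)
mask = np.array([1, 0], dtype=np.uint8)
out = horizontal_flip_ref(img, mask)  # flips image 0, leaves image 1
```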
Note
This uses a custom Triton kernel. For standalone flip operations, PyTorch's tensor.flip(-1) is highly optimized and may be comparable. The main benefit is when fusing with crop (see fused_crop_flip).
Utility Functions¶
triton_augment.enable_autotune ¶
Enable auto-tuning of kernel configurations.
triton_augment.disable_autotune ¶
Disable auto-tuning and use default kernel configurations.
triton_augment.is_autotune_enabled ¶
Check if auto-tuning is currently enabled.
Returns:
| Name | Type | Description |
|---|---|---|
| bool | bool | True if auto-tuning is enabled, False otherwise |
triton_augment.warmup_cache ¶
warmup_cache(batch_sizes: Tuple[int, ...] = (32, 64), image_sizes: Tuple[int, ...] = (224, 256, 512), verbose: bool = True)
Pre-populate the auto-tuning cache for common image sizes.
This function runs the fused kernel with various common configurations to trigger auto-tuning and cache the optimal settings. This eliminates the 5-10 second delay on first use.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch_sizes | Tuple[int, ...] | Tuple of batch sizes to warm up | (32, 64) |
| image_sizes | Tuple[int, ...] | Tuple of square image sizes to warm up | (224, 256, 512) |
| verbose | bool | Whether to print progress messages | True |
Input Requirements¶
Transform Classes¶
Transform classes (e.g., TritonFusedAugment, TritonColorJitter, etc.) accept:
- Device: CUDA (GPU) or CPU (CPU tensors are automatically moved to the GPU)
- Shape: (C, H, W), (N, C, H, W), or (N, T, C, H, W) (3D, 4D, or 5D for video)
- Dtype: float32 or float16
- Range: [0, 1] for color operations (required)
Notes:
- 3D tensors (C, H, W) are automatically converted to (1, C, H, W) internally for processing
- 5D tensors (N, T, C, H, W) are supported for video augmentation (batch, frames, channels, height, width)
- For 5D inputs, use same_on_frame=True (default) for consistent augmentation across frames, or same_on_frame=False for independent per-frame augmentation
- After normalization, values can be outside the [0, 1] range
- CPU tensors are automatically transferred to CUDA for GPU processing
Functional API¶
The functional API (e.g., fused_augment(), crop(), normalize(), etc.) expects:
- Device: CUDA (GPU); inputs must already be on a CUDA device
- Shape: (N, C, H, W) (4D tensors only)
- Dtype: float32 or float16
- Range: [0, 1] for color operations (required)
Note: Transform classes handle 3D/5D normalization internally. If using the functional API directly, ensure inputs are already in 4D format (N, C, H, W).
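The shape handling described above can be sketched as follows. The 3D case is documented; folding video frames into the batch dimension for the 5D case is an assumption about the internals, shown only to illustrate why the functional API sees 4D input:

```python
import numpy as np

def to_nchw(x):
    """Normalize input to 4D (N, C, H, W), returning the original
    shape so the result can be restored afterwards."""
    if x.ndim == 3:    # (C, H, W) -> (1, C, H, W), as documented
        return x[None], x.shape
    if x.ndim == 5:    # (N, T, C, H, W) -> (N*T, C, H, W), assumed
        n, t, c, h, w = x.shape
        return x.reshape(n * t, c, h, w), x.shape
    return x, x.shape  # already (N, C, H, W)

video = np.zeros((2, 4, 3, 16, 16))  # 2 clips of 4 frames each
flat, orig_shape = to_nchw(video)
restored = flat.reshape(orig_shape)
```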
Performance Tips¶
1. Use Fused Kernel Even for Partial Operations¶
Key insight: Even if you only need a subset of operations, use TritonFusedAugment or F.fused_augment for best performance! Simply set unused operations to no-op values:
```python
# Example: only need crop + normalize (no flip, no color jitter)
transform = ta.TritonFusedAugment(
    crop_size=224,
    horizontal_flip_p=0.0,  # No flip
    brightness=0.0,         # No brightness
    contrast=0.0,           # No contrast
    saturation=0.0,         # No saturation
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)
# Still faster than calling crop() + normalize() separately!
```
The fused kernel is optimized to skip operations set to no-op values at compile time.
2. Auto-Tuning¶
Enable auto-tuning for optimal performance on your specific GPU and data sizes:
```python
import triton_augment as ta

ta.enable_autotune()  # Enable once at the start of training

# Optional: pre-compile kernels for your data sizes
ta.warmup_cache(batch_sizes=(32, 64), image_sizes=(224, 512))
```
See Auto-Tuning Guide for detailed configuration.
Additional Resources¶
- Quick Start Guide: Training integration examples
- Float16 Support: Half-precision performance and memory savings
- Contrast Notes: Differences between fast and torchvision-exact contrast
- Batch Behavior: Understanding the same_on_batch parameter
- Benchmark Results: Detailed performance comparisons