Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions tests/pytorch/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,3 +1233,92 @@ def check_weights():
with autocast(enabled=with_quantization, recipe=quantization_recipe):
y = module(x, **kwargs)
check_weights()


@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear"),
)
@pytest.mark.parametrize(
    "quantization",
    ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_quantizer_columnwise_usage_after_eval(
    module_name: str,
    quantization: str,
) -> None:
    """Check weight-quantizer columnwise usage across train -> eval -> train.

    Eval mode removes the columnwise usage from the quantizer, so modules
    need to reset columnwise mode each time in backward instead of using it
    from quantizer state, which may be stale.
    """
    # Skip when the current hardware/software stack lacks the requested recipe.
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    sequence_length = 32
    hidden_size = 32

    # Map the parametrized recipe name onto a concrete recipe object.
    if quantization == "fp8_delayed_scaling":
        quantization_recipe = recipe.DelayedScaling()
    elif quantization == "fp8_current_scaling":
        quantization_recipe = recipe.Float8CurrentScaling()
    else:
        quantization_recipe = recipe.MXFP8BlockScaling()

    # Build the module with quantized (primary FP8) weights from the start.
    with quantized_model_init(enabled=True, recipe=quantization_recipe):
        if module_name == "Linear":
            module = Linear(hidden_size, hidden_size, bias=False)
        elif module_name == "LayerNormLinear":
            module = LayerNormLinear(hidden_size, hidden_size, bias=False)
        elif module_name == "LayerNormMLP":
            module = LayerNormMLP(hidden_size, hidden_size, bias=False)
        elif module_name == "GroupedLinear":
            module = GroupedLinear(1, hidden_size, hidden_size, bias=False)
        else:
            raise AssertionError(f"Unhandled module_name {module_name}")
    module = module.cuda()

    def get_weight_quantizers():
        """Return the per-weight ``_quantizer`` objects whose state matters."""
        if module_name == "LayerNormMLP":
            return [module.fc1_weight._quantizer, module.fc2_weight._quantizer]
        if module_name == "GroupedLinear":
            return [module.weight0._quantizer]
        return [module.weight._quantizer]

    def run_forward(is_eval: bool):
        """Run one forward pass (plus backward when training) under autocast."""
        x = torch.randn(sequence_length, hidden_size, device="cuda", requires_grad=not is_eval)
        kwargs = {}
        if module_name == "GroupedLinear":
            # GroupedLinear requires explicit per-group split sizes.
            kwargs["m_splits"] = [sequence_length]
        ctx = torch.no_grad() if is_eval else torch.enable_grad()
        with ctx, autocast(enabled=True, recipe=quantization_recipe):
            y = module(x, **kwargs)
        if not is_eval:
            y.sum().backward()

    # 1. Training forward -- should set columnwise=True.
    run_forward(is_eval=False)
    for q in get_weight_quantizers():
        assert (
            q.columnwise_usage
        ), "After an initial training forward, weight quantizer should have columnwise_usage=True"

    # 2. Eval forward -- should set columnwise=False on primary FP8 weight
    # quantizers, simulating the start of an evaluation loop.
    run_forward(is_eval=True)
    for q in get_weight_quantizers():
        assert (
            not q.columnwise_usage
        ), "After an eval forward, weight quantizer should have columnwise_usage=False"

    # 3. Training forward again without eval.
    run_forward(is_eval=False)
    for q in get_weight_quantizers():
        assert q.columnwise_usage, (
            "After resuming training following an eval forward, the weight "
            "quantizer must have columnwise_usage=True."
        )
11 changes: 3 additions & 8 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,11 @@ def forward(
is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
)
# No need to set the quantizer states if weight is already quantized
# for debug mode we create quantizer every iteration, thus we need to set the quantizer states
if weight_quantizers[0] is not None and (
not isinstance(weights[0], QuantizedTensorStorage) or debug
):
if weight_quantizers[0] is not None:
if isinstance(weights[0], QuantizedTensorStorage) and not debug:
weight_quantizers = [weight._quantizer for weight in weights]
for weight_quantizer in weight_quantizers:
Comment on lines +139 to 141
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think these changes are needed for this file or for any of the other files in the PR. If the weight is already quantized, it doesn't make sense to change its internal quantizer and leave the quantized weight and its internal quantizer in a state of conflict with each other.

In general, in the case of quantized_model_init, if we change a quantized_tensor's internal quantizer, the quantized_tensor should also be updated to have the appropriate usages.

weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
elif isinstance(weights[0], QuantizedTensorStorage):
# If weights are already quantized, no need to set quantizer states
weight_quantizers = [weight._quantizer for weight in weights]
if output_quantizers[0] is not None:
for output_quantizer in output_quantizers:
output_quantizer.set_usage(rowwise=True, columnwise=False)
Expand Down
5 changes: 1 addition & 4 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,7 @@ def forward(
# for debug mode we create quantizer every iteration, thus we need to set the quantizer states
if is_weight_param_quantized and not debug:
weight_quantizer = weight._quantizer
elif weight_quantizer is not None:
# FSDP2: Skip columnwise/transpose creation during forward
# to avoid accumulating caches across layers. Backward's
# FSDP2 all-gather will recreate them. (Issue #2681)
if weight_quantizer is not None:
weight_quantizer.set_usage(
rowwise=True,
columnwise=is_grad_enabled and not is_fsdp2 and backward_override is None,
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,15 +492,15 @@ def _forward(
# for debug mode we create quantizer every iteration, thus we need to set the quantizer states
if isinstance(fc1_weight, QuantizedTensorStorage) and not debug:
fc1_weight_quantizer = fc1_weight._quantizer
elif fc1_weight_quantizer is not None:
if fc1_weight_quantizer is not None:
fc1_weight_quantizer.set_usage(
rowwise=True,
columnwise=is_grad_enabled and not fsdp2_skip_columnwise,
)

if isinstance(fc2_weight, QuantizedTensorStorage) and not debug:
fc2_weight_quantizer = fc2_weight._quantizer
elif fc2_weight_quantizer is not None:
if fc2_weight_quantizer is not None:
fc2_weight_quantizer.set_usage(
rowwise=True,
columnwise=is_grad_enabled and not fsdp2_skip_columnwise,
Expand Down
8 changes: 3 additions & 5 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,9 @@ def _linear_forward_impl(
weightmat = weight
if fp8 or debug:
# Configure quantizer
# No need to set the quantizer states if weight is already quantized
# for debug mode we create quantizer every iteration, thus we need to set the quantizer states
if weight_quantizer is not None and (not isinstance(weight, QuantizedTensor) or debug):
if weight_quantizer is not None:
if isinstance(weight, QuantizedTensor) and not debug:
weight_quantizer = weight._quantizer
Comment on lines +264 to +266
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Dropped defensive weight_quantizer assignment loses the quantize_weight call

The original elif isinstance(weight, QuantizedTensor): weight_quantizer = weight._quantizer handled the case where weight_quantizer arrives as None while weight is already a QuantizedTensor. In that path, quantize_weight immediately dereferences quantizer.rowwise_usage (line 710 of base.py) and will raise AttributeError: 'NoneType' object has no attribute 'rowwise_usage'.

The new code only re-assigns weight_quantizer when it is already non-None, so the previously guarded scenario now crashes instead of falling back to the weight's own quantizer. The missing assignment should be:

if weight_quantizer is not None:
    if isinstance(weight, QuantizedTensor) and not debug:
        weight_quantizer = weight._quantizer
    columnwise_usage = ...
    weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
elif isinstance(weight, QuantizedTensor):
    # weight_quantizer is None but weight is pre-quantized — pick up its quantizer
    weight_quantizer = weight._quantizer

columnwise_usage = is_grad_enabled and inp.requires_grad and not is_fsdp2
if backward_override is not None:
columnwise_usage = False
Expand All @@ -273,8 +273,6 @@ def _linear_forward_impl(
and not in_fp8_activation_recompute_phase()
)
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
elif isinstance(weight, QuantizedTensor):
weight_quantizer = weight._quantizer
# Get quantized weight
update_ws = is_first_microbatch is None or is_first_microbatch
weightmat, new_weight_workspace = quantize_weight(
Expand Down
12 changes: 12 additions & 0 deletions transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ def fuser_forward(
if fc1_op.weight.quantizer is not None:
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
fc1_op.weight.quantizer = fc1_weight_quantizer
if fc1_op.weight.quantized_tensors is not None:
for qt in fc1_op.weight.quantized_tensors:
if getattr(qt, "_quantizer", None) is not None:
qt._quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
grouped_fc1_weight = fc1_op.weight
else:
if fc1_op.weight.rowwise_data is None:
Expand All @@ -234,6 +238,8 @@ def fuser_forward(
quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
quantized_fc1_weights.append(quantizer(weight))
else:
if getattr(weight, "_quantizer", None) is not None:
weight._quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
quantized_fc1_weights.append(weight)
grouped_fc1_weight = quantized_fc1_weights

Expand All @@ -246,6 +252,10 @@ def fuser_forward(
if fc2_op.weight.quantizer is not None:
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
fc2_op.weight.quantizer = fc2_weight_quantizer
if fc2_op.weight.quantized_tensors is not None:
for qt in fc2_op.weight.quantized_tensors:
if getattr(qt, "_quantizer", None) is not None:
qt._quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
grouped_fc2_weight = fc2_op.weight
else:
if fc2_op.weight.rowwise_data is None:
Expand All @@ -267,6 +277,8 @@ def fuser_forward(
quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
quantized_fc2_weights.append(quantizer(weight))
else:
if getattr(weight, "_quantizer", None) is not None:
weight._quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
quantized_fc2_weights.append(weight)
grouped_fc2_weight = quantized_fc2_weights

Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,14 +653,14 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
# is needed based on whether it's a forward or backward pass.
# If not resharded, the same all-gathered weights are reused in backward,
# so both usages may be needed.
training_state = param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
if reshard_after_forward:
training_state = param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
rowwise_usage = not is_backward_pass
columnwise_usage = is_backward_pass
else:
rowwise_usage = True
columnwise_usage = self._quantizer.columnwise_usage
columnwise_usage = is_backward_pass or torch.is_grad_enabled()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it still makes sense to use self._quantizer.columnwise_usage as the source of truth for what data is
"really available" in the sharded quantized tensor, and to throw an error if that usage doesn't match
`is_backward_pass or torch.is_grad_enabled()` (similar to the mxfp8 tensor).

What we are doing here is silently creating columnwise data after the all-gather for the all-gathered tensor, even though the original sharded data tensor didn't have that data.

In my opinion, I am against any change here, since even performing such a validation and throwing an error is going to incur CPU overhead when using
torch.is_grad_enabled.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment in every other FSDP2 related changes


# For 2D block scaling (128x128 blocks), columnwise data and scales are
# the transpose of rowwise data and scales. Only all-gather the rowwise
Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/tensor/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,16 +885,16 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
# If not resharded after forward pass, the same weights allgathered in forward
# are used again in backward and so we dont change the quantizer usages which might need
# both rowwise and columnwise usages.
training_state = param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
if reshard_after_forward:
training_state = param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# In case of hopper/L40, only one of data/transpose is needed
# based on forward or backward pass. So setting the quantizer usages appropriately.
rowwise_usage = not is_backward_pass
columnwise_usage = is_backward_pass
else:
rowwise_usage = True
columnwise_usage = self._quantizer.columnwise_usage
columnwise_usage = is_backward_pass or torch.is_grad_enabled()
sharded_tensors = (self._data,)
metadata = (self._scale_inv, rowwise_usage, columnwise_usage, self._fp8_dtype)
return sharded_tensors, metadata
Expand Down
13 changes: 10 additions & 3 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,9 +667,9 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
# If not resharded after forward pass, the same weights allgathered in forward
# are used again in backward. And hence if we need the columnwise data/scale_inv,
# we need to send them as well for allgather in forward pass itself.
training_state = param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
if reshard_after_forward:
training_state = param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# Allgather only the necessary tensors based on forward/backward pass
rowwise_usage = not is_backward_pass
columnwise_usage = is_backward_pass
Expand All @@ -681,9 +681,16 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
else:
# rowwise usage is always needed for forward pass.
rowwise_usage = True
columnwise_usage = is_backward_pass or torch.is_grad_enabled()
sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
columnwise_usage = self._quantizer.columnwise_usage
if columnwise_usage:
if self._columnwise_data is None or columnwise_scale_inv is None:
raise RuntimeError(
"FSDP2 (reshard_after_forward=False) needs columnwise MXFP8 data "
"for the upcoming backward pass, but the local shard has none. "
"Ensure the weight is quantized with columnwise_usage=True before "
"this all-gather."
)
Comment on lines +684 to +693
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 torch.is_grad_enabled() can be True during eval without torch.no_grad()

model.eval() alone does not disable the gradient tape — torch.is_grad_enabled() stays True unless the caller wraps the eval loop with torch.no_grad(). In that situation columnwise_usage becomes True, but the local shard may not have _columnwise_data (it was never quantized with columnwise support during eval), so the new RuntimeError fires.

Users who ran eval with grads enabled previously got silently incorrect (stale) data; they now get a hard crash. While the crash is more correct, the error message could guide them:

Suggested change
columnwise_usage = is_backward_pass or torch.is_grad_enabled()
sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
columnwise_usage = self._quantizer.columnwise_usage
if columnwise_usage:
if self._columnwise_data is None or columnwise_scale_inv is None:
raise RuntimeError(
"FSDP2 (reshard_after_forward=False) needs columnwise MXFP8 data "
"for the upcoming backward pass, but the local shard has none. "
"Ensure the weight is quantized with columnwise_usage=True before "
"this all-gather."
)
if self._columnwise_data is None or columnwise_scale_inv is None:
raise RuntimeError(
"FSDP2 (reshard_after_forward=False) needs columnwise MXFP8 data "
"for the upcoming backward pass, but the local shard has none. "
"Ensure the weight is quantized with columnwise_usage=True before "
"this all-gather. If you are running evaluation without requiring "
"gradients, wrap the eval loop with torch.no_grad()."
)

# If weights are not resharded after forward, then both
# rowwise and columnwise data/scale_inv need to be allgathered.
sharded_tensors += (self._columnwise_data, columnwise_scale_inv)
Expand Down
Loading