diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 7f2f24fd69..93fdab7604 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -1233,3 +1233,92 @@ def check_weights():
         with autocast(enabled=with_quantization, recipe=quantization_recipe):
             y = module(x, **kwargs)
     check_weights()
+
+
+@pytest.mark.parametrize(
+    "module_name",
+    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear"),
+)
+@pytest.mark.parametrize(
+    "quantization",
+    ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
+)
+def test_quantizer_columnwise_usage_after_eval(
+    module_name: str,
+    quantization: str,
+) -> None:
+    """
+    An eval forward clears the columnwise usage on the weight quantizer, so
+    modules must re-derive the columnwise usage on every forward pass instead
+    of relying on potentially stale quantizer state.
+    """
+
+    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if quantization == "mxfp8" and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+
+    sequence_length = 32
+    hidden_size = 32
+
+    if quantization == "fp8_delayed_scaling":
+        quantization_recipe = recipe.DelayedScaling()
+    elif quantization == "fp8_current_scaling":
+        quantization_recipe = recipe.Float8CurrentScaling()
+    else:
+        quantization_recipe = recipe.MXFP8BlockScaling()
+
+    with quantized_model_init(enabled=True, recipe=quantization_recipe):
+        if module_name == "Linear":
+            module = Linear(hidden_size, hidden_size, bias=False)
+        elif module_name == "LayerNormLinear":
+            module = LayerNormLinear(hidden_size, hidden_size, bias=False)
+        elif module_name == "LayerNormMLP":
+            module = LayerNormMLP(hidden_size, hidden_size, bias=False)
+        elif module_name == "GroupedLinear":
+            module = GroupedLinear(1, hidden_size, hidden_size, bias=False)
+        else:
+            raise AssertionError(f"Unhandled module_name {module_name}")
+    module = module.cuda()
+
+    def get_weight_quantizers():
+        """Return the per-weight ``_quantizer`` objects whose state matters."""
+        if module_name == "LayerNormMLP":
+            return [module.fc1_weight._quantizer, module.fc2_weight._quantizer]
+        if module_name == "GroupedLinear":
+            return [module.weight0._quantizer]
+        return [module.weight._quantizer]
+
+    def run_forward(is_eval: bool):
+        x = torch.randn(sequence_length, hidden_size, device="cuda", requires_grad=not is_eval)
+        kwargs = {}
+        if module_name == "GroupedLinear":
+            kwargs["m_splits"] = [sequence_length]
+        ctx = torch.no_grad() if is_eval else torch.enable_grad()
+        with ctx, autocast(enabled=True, recipe=quantization_recipe):
+            y = module(x, **kwargs)
+        if not is_eval:
+            y.sum().backward()
+
+    # 1. Training forward -- should set columnwise=True.
+    run_forward(is_eval=False)
+    for q in get_weight_quantizers():
+        assert (
+            q.columnwise_usage
+        ), "After an initial training forward, weight quantizer should have columnwise_usage=True"
+
+    # 2. Eval forward -- should set columnwise=False on primary FP8 weight
+    # quantizers, simulating the start of an evaluation loop.
+    run_forward(is_eval=True)
+    for q in get_weight_quantizers():
+        assert (
+            not q.columnwise_usage
+        ), "After an eval forward, weight quantizer should have columnwise_usage=False"
+
+    # 3. Training forward again without eval.
+    run_forward(is_eval=False)
+    for q in get_weight_quantizers():
+        assert q.columnwise_usage, (
+            "After resuming training following an eval forward, the weight "
+            "quantizer must have columnwise_usage=True."
+        )
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 720a274119..4693af2228 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -135,16 +135,11 @@ def forward(
                     is_fp8_activation_recompute_enabled()
                     and not in_fp8_activation_recompute_phase()
                 )
-            # No need to set the quantizer states if weight is already quantized
-            # for debug mode we create quantizer every iteration, thus we need to set the quantizer states
-            if weight_quantizers[0] is not None and (
-                not isinstance(weights[0], QuantizedTensorStorage) or debug
-            ):
+            if weight_quantizers[0] is not None:
+                if isinstance(weights[0], QuantizedTensorStorage) and not debug:
+                    weight_quantizers = [weight._quantizer for weight in weights]
                 for weight_quantizer in weight_quantizers:
                     weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
-            elif isinstance(weights[0], QuantizedTensorStorage):
-                # If weights are already quantized, no need to set quantizer states
-                weight_quantizers = [weight._quantizer for weight in weights]
             if output_quantizers[0] is not None:
                 for output_quantizer in output_quantizers:
                     output_quantizer.set_usage(rowwise=True, columnwise=False)
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index d69e643c4c..5481d5a7fd 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -308,10 +308,7 @@ def forward(
             # for debug mode we create quantizer every iteration, thus we need to set the quantizer states
             if is_weight_param_quantized and not debug:
                 weight_quantizer = weight._quantizer
-            elif weight_quantizer is not None:
-                # FSDP2: Skip columnwise/transpose creation during forward
-                # to avoid accumulating caches across layers. Backward's
-                # FSDP2 all-gather will recreate them. (Issue #2681)
+            if weight_quantizer is not None:
                 weight_quantizer.set_usage(
                     rowwise=True,
                     columnwise=is_grad_enabled and not is_fsdp2 and backward_override is None,
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 4fa7eb2856..61de10fa68 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -492,7 +492,7 @@ def _forward(
         # for debug mode we create quantizer every iteration, thus we need to set the quantizer states
         if isinstance(fc1_weight, QuantizedTensorStorage) and not debug:
             fc1_weight_quantizer = fc1_weight._quantizer
-        elif fc1_weight_quantizer is not None:
+        if fc1_weight_quantizer is not None:
             fc1_weight_quantizer.set_usage(
                 rowwise=True,
                 columnwise=is_grad_enabled and not fsdp2_skip_columnwise,
@@ -500,7 +500,7 @@

         if isinstance(fc2_weight, QuantizedTensorStorage) and not debug:
             fc2_weight_quantizer = fc2_weight._quantizer
-        elif fc2_weight_quantizer is not None:
+        if fc2_weight_quantizer is not None:
             fc2_weight_quantizer.set_usage(
                 rowwise=True,
                 columnwise=is_grad_enabled and not fsdp2_skip_columnwise,
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 7498760af5..9a827c2002 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -261,9 +261,9 @@ def _linear_forward_impl(
     weightmat = weight
     if fp8 or debug:
         # Configure quantizer
-        # No need to set the quantizer states if weight is already quantized
-        # for debug mode we create quantizer every iteration, thus we need to set the quantizer states
-        if weight_quantizer is not None and (not isinstance(weight, QuantizedTensor) or debug):
+        if weight_quantizer is not None:
+            if isinstance(weight, QuantizedTensor) and not debug:
+                weight_quantizer = weight._quantizer
            columnwise_usage = is_grad_enabled and inp.requires_grad and not is_fsdp2
            if backward_override is not None:
                columnwise_usage = False
@@ -273,8 +273,6 @@
                and not in_fp8_activation_recompute_phase()
            )
            weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
-        elif isinstance(weight, QuantizedTensor):
-            weight_quantizer = weight._quantizer
        # Get quantized weight
        update_ws = is_first_microbatch is None or is_first_microbatch
        weightmat, new_weight_workspace = quantize_weight(
diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
index cad31e2c50..e6a19aea38 100644
--- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
+++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
@@ -214,6 +214,10 @@ def fuser_forward(
             if fc1_op.weight.quantizer is not None:
                 fc1_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
                 fc1_op.weight.quantizer = fc1_weight_quantizer
+            if fc1_op.weight.quantized_tensors is not None:
+                for qt in fc1_op.weight.quantized_tensors:
+                    if getattr(qt, "_quantizer", None) is not None:
+                        qt._quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
             grouped_fc1_weight = fc1_op.weight
         else:
             if fc1_op.weight.rowwise_data is None:
@@ -234,6 +238,8 @@
                     quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
                     quantized_fc1_weights.append(quantizer(weight))
                 else:
+                    if getattr(weight, "_quantizer", None) is not None:
+                        weight._quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
                     quantized_fc1_weights.append(weight)
             grouped_fc1_weight = quantized_fc1_weights

@@ -246,6 +252,10 @@
             if fc2_op.weight.quantizer is not None:
                 fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
                 fc2_op.weight.quantizer = fc2_weight_quantizer
+            if fc2_op.weight.quantized_tensors is not None:
+                for qt in fc2_op.weight.quantized_tensors:
+                    if getattr(qt, "_quantizer", None) is not None:
+                        qt._quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
             grouped_fc2_weight = fc2_op.weight
         else:
             if fc2_op.weight.rowwise_data is None:
@@ -267,6 +277,8 @@
                     quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
                     quantized_fc2_weights.append(quantizer(weight))
                 else:
+                    if getattr(weight, "_quantizer", None) is not None:
+                        weight._quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
                     quantized_fc2_weights.append(weight)
             grouped_fc2_weight = quantized_fc2_weights

diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
index 914397b9b6..81e20b481d 100644
--- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -653,14 +653,14 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
         # is needed based on whether it's a forward or backward pass.
         # If not resharded, the same all-gathered weights are reused in backward,
         # so both usages may be needed.
+        training_state = param_group._training_state
+        is_backward_pass = training_state == TrainingState.PRE_BACKWARD
         if reshard_after_forward:
-            training_state = param_group._training_state
-            is_backward_pass = training_state == TrainingState.PRE_BACKWARD
             rowwise_usage = not is_backward_pass
             columnwise_usage = is_backward_pass
         else:
             rowwise_usage = True
-            columnwise_usage = self._quantizer.columnwise_usage
+            columnwise_usage = is_backward_pass or torch.is_grad_enabled()

         # For 2D block scaling (128x128 blocks), columnwise data and scales are
         # the transpose of rowwise data and scales. Only all-gather the rowwise
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index ed6091c85b..e1a40ca772 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -885,16 +885,16 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
         # If not resharded after forward pass, the same weights allgathered in forward
         # are used again in backward and so we dont change the quantizer usages which might need
         # both rowwise and columnwise usages.
+        training_state = param_group._training_state
+        is_backward_pass = training_state == TrainingState.PRE_BACKWARD
         if reshard_after_forward:
-            training_state = param_group._training_state
-            is_backward_pass = training_state == TrainingState.PRE_BACKWARD
             # In case of hopper/L40, only one of data/transpose is needed
             # based on forward or backward pass. So setting the quantizer usages appropriately.
             rowwise_usage = not is_backward_pass
             columnwise_usage = is_backward_pass
         else:
             rowwise_usage = True
-            columnwise_usage = self._quantizer.columnwise_usage
+            columnwise_usage = is_backward_pass or torch.is_grad_enabled()
         sharded_tensors = (self._data,)
         metadata = (self._scale_inv, rowwise_usage, columnwise_usage, self._fp8_dtype)
         return sharded_tensors, metadata
diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
index 5cab519c79..e68ae81614 100644
--- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py
+++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
@@ -667,9 +667,9 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
         # If not resharded after forward pass, the same weights allgathered in forward
         # are used again in backward. And hence if we need the columnwise data/scale_inv,
         # we need to send them as well for allgather in forward pass itself.
+        training_state = param_group._training_state
+        is_backward_pass = training_state == TrainingState.PRE_BACKWARD
         if reshard_after_forward:
-            training_state = param_group._training_state
-            is_backward_pass = training_state == TrainingState.PRE_BACKWARD
             # Allgather only the necessary tensors based on forward/backward pass
             rowwise_usage = not is_backward_pass
             columnwise_usage = is_backward_pass
@@ -681,9 +681,16 @@
         else:
             # rowwise usage is always needed for forward pass.
             rowwise_usage = True
+            columnwise_usage = is_backward_pass or torch.is_grad_enabled()
         sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
-        columnwise_usage = self._quantizer.columnwise_usage
         if columnwise_usage:
+            if self._columnwise_data is None or columnwise_scale_inv is None:
+                raise RuntimeError(
+                    "FSDP2 (reshard_after_forward=False) needs columnwise MXFP8 data "
+                    "for the upcoming backward pass, but the local shard has none. "
+                    "Ensure the weight is quantized with columnwise_usage=True before "
+                    "this all-gather."
+                )
             # If weights are not resharded after forward, then both
             # rowwise and columnwise data/scale_inv need to be allgathered.
             sharded_tensors += (self._columnwise_data, columnwise_scale_inv)
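
The behaviour the patch enforces, and that the new test exercises, can be summarised outside Transformer Engine with a small self-contained sketch. ToyQuantizer and ToyLinear below are hypothetical stand-ins, not TE classes; the point is only that the columnwise usage flag is re-derived from the current autograd state on every forward instead of being read back from the (possibly stale) state left behind by an earlier eval pass.

    import torch

    class ToyQuantizer:
        """Stand-in for a TE quantizer: remembers which layouts to produce next."""

        def __init__(self):
            self.rowwise_usage = True
            self.columnwise_usage = True

        def set_usage(self, *, rowwise: bool, columnwise: bool):
            self.rowwise_usage = rowwise
            self.columnwise_usage = columnwise

    class ToyLinear(torch.nn.Linear):
        """Mimics the fixed modules: quantizer usage is reset on every forward."""

        def __init__(self, in_features, out_features):
            super().__init__(in_features, out_features, bias=False)
            self.weight_quantizer = ToyQuantizer()

        def forward(self, x):
            # Derive columnwise usage from the current autograd state rather
            # than reusing whatever the previous forward left behind.
            self.weight_quantizer.set_usage(
                rowwise=True,
                columnwise=torch.is_grad_enabled() and x.requires_grad,
            )
            return super().forward(x)

    module = ToyLinear(32, 32)
    x = torch.randn(4, 32, requires_grad=True)

    module(x)  # training forward
    assert module.weight_quantizer.columnwise_usage

    with torch.no_grad():
        module(torch.randn(4, 32))  # eval forward clears columnwise usage
    assert not module.weight_quantizer.columnwise_usage

    module(x)  # training again: usage is re-derived, not read from stale state
    assert module.weight_quantizer.columnwise_usage

The fsdp_pre_all_gather changes apply the same idea to deciding which FP8 layouts to all-gather. A minimal sketch of that decision, written as a standalone helper rather than a TE function, assuming reshard_after_forward and is_backward_pass are computed as in the patch:

    import torch

    def decide_allgather_usages(reshard_after_forward: bool, is_backward_pass: bool):
        """Return (rowwise_usage, columnwise_usage) for the FSDP2 all-gather."""
        if reshard_after_forward:
            # Weights are re-gathered for backward, so each pass needs only one layout.
            return (not is_backward_pass, is_backward_pass)
        # Weights stay gathered after forward, so also fetch the columnwise layout
        # whenever a backward pass can follow, instead of trusting stale quantizer state.
        return (True, is_backward_pass or torch.is_grad_enabled())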