[PyTorch] Fix stale columnwise data usage #2925
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
```diff
@@ -135,16 +135,11 @@ def forward(
             is_fp8_activation_recompute_enabled()
             and not in_fp8_activation_recompute_phase()
         )
-        # No need to set the quantizer states if weight is already quantized
-        # for debug mode we create quantizer every iteration, thus we need to set the quantizer states
-        if weight_quantizers[0] is not None and (
-            not isinstance(weights[0], QuantizedTensorStorage) or debug
-        ):
+        if weight_quantizers[0] is not None:
+            if isinstance(weights[0], QuantizedTensorStorage) and not debug:
+                weight_quantizers = [weight._quantizer for weight in weights]
             for weight_quantizer in weight_quantizers:
```
Comment on lines +139 to 141

**Collaborator:** I don't think these changes are needed for this file or any of the other files in the PR. If the weight is already quantized, it doesn't make sense to change its internal quantizer and leave the quantized weight and its internal quantizer in conflict with each other. In general, in the quantized_model_init case, if we change a quantized tensor's internal quantizer, the quantized tensor should also be updated to have the appropriate usages.
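The invariant the reviewer describes can be illustrated with a minimal sketch. The `ToyQuantizer`/`ToyQuantizedTensor` names below are hypothetical stand-ins, not the Transformer Engine API: the point is only that a quantized tensor's stored data must agree with its internal quantizer's usage flags, so changing one without updating the other creates exactly the conflict called out above.

```python
# Illustrative sketch only; these classes are NOT the Transformer Engine API.
from dataclasses import dataclass, field


@dataclass
class ToyQuantizer:
    rowwise_usage: bool = True
    columnwise_usage: bool = False

    def set_usage(self, rowwise=None, columnwise=None):
        if rowwise is not None:
            self.rowwise_usage = rowwise
        if columnwise is not None:
            self.columnwise_usage = columnwise


@dataclass
class ToyQuantizedTensor:
    rowwise_data: object = None
    columnwise_data: object = None
    quantizer: ToyQuantizer = field(default_factory=ToyQuantizer)

    def update_usage(self):
        # Materialize or drop data so it matches the quantizer's flags.
        if self.quantizer.columnwise_usage and self.columnwise_data is None:
            self.columnwise_data = "transposed-bytes"  # placeholder payload
        if not self.quantizer.columnwise_usage:
            self.columnwise_data = None


t = ToyQuantizedTensor(rowwise_data="bytes")
# Changing only the internal quantizer leaves tensor and quantizer in conflict:
t.quantizer.set_usage(columnwise=True)
assert t.columnwise_data is None  # flag says columnwise, data says no
# Updating the tensor as well restores the invariant:
t.update_usage()
assert t.columnwise_data is not None
```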
```diff
                 weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
-        elif isinstance(weights[0], QuantizedTensorStorage):
-            # If weights are already quantized, no need to set quantizer states
-            weight_quantizers = [weight._quantizer for weight in weights]
         if output_quantizers[0] is not None:
             for output_quantizer in output_quantizers:
                 output_quantizer.set_usage(rowwise=True, columnwise=False)
```
```diff
@@ -261,9 +261,9 @@ def _linear_forward_impl(
     weightmat = weight
     if fp8 or debug:
         # Configure quantizer
-        # No need to set the quantizer states if weight is already quantized
-        # for debug mode we create quantizer every iteration, thus we need to set the quantizer states
-        if weight_quantizer is not None and (not isinstance(weight, QuantizedTensor) or debug):
+        if weight_quantizer is not None:
+            if isinstance(weight, QuantizedTensor) and not debug:
+                weight_quantizer = weight._quantizer
```
Comment on lines +264 to +266

**Contributor:** The original `elif isinstance(weight, QuantizedTensor)` branch picked up the weight's internal quantizer even when `weight_quantizer` was `None`. The new code only re-assigns inside `if weight_quantizer is not None:`, so that case is lost:

```python
if weight_quantizer is not None:
    if isinstance(weight, QuantizedTensor) and not debug:
        weight_quantizer = weight._quantizer
    columnwise_usage = ...
    weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
elif isinstance(weight, QuantizedTensor):
    # weight_quantizer is None but weight is pre-quantized; pick up its quantizer
    weight_quantizer = weight._quantizer
```
```diff
             columnwise_usage = is_grad_enabled and inp.requires_grad and not is_fsdp2
             if backward_override is not None:
                 columnwise_usage = False
@@ -273,8 +273,6 @@ def _linear_forward_impl(
                 and not in_fp8_activation_recompute_phase()
             )
             weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
-        elif isinstance(weight, QuantizedTensor):
-            weight_quantizer = weight._quantizer
         # Get quantized weight
         update_ws = is_first_microbatch is None or is_first_microbatch
         weightmat, new_weight_workspace = quantize_weight(
```
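The `is_first_microbatch` convention visible in the hunk above can be restated in isolation. This is a plain paraphrase of the condition in the diff, with an illustrative helper name, not the library's actual API: `None` means no microbatch caching, so the quantized-weight workspace is always refreshed; `True` refreshes it on the first microbatch of a step; `False` reuses the cached copy.

```python
# Restates `update_ws = is_first_microbatch is None or is_first_microbatch`
# from the diff above; the helper name is illustrative, not a library function.
def should_update_workspace(is_first_microbatch):
    # None  -> no microbatching, always rebuild the quantized weight
    # True  -> first microbatch of a step, rebuild and cache
    # False -> later microbatch, reuse the cached quantized weight
    return is_first_microbatch is None or is_first_microbatch


assert should_update_workspace(None) is True
assert should_update_workspace(True) is True
assert should_update_workspace(False) is False
```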
```diff
@@ -653,14 +653,14 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
         # is needed based on whether it's a forward or backward pass.
         # If not resharded, the same all-gathered weights are reused in backward,
         # so both usages may be needed.
+        training_state = param_group._training_state
+        is_backward_pass = training_state == TrainingState.PRE_BACKWARD
         if reshard_after_forward:
-            training_state = param_group._training_state
-            is_backward_pass = training_state == TrainingState.PRE_BACKWARD
             rowwise_usage = not is_backward_pass
             columnwise_usage = is_backward_pass
         else:
             rowwise_usage = True
-            columnwise_usage = self._quantizer.columnwise_usage
+            columnwise_usage = is_backward_pass or torch.is_grad_enabled()
```
**Collaborator:** I think it still makes sense to use `self._quantizer.columnwise_usage` as the source of truth for what data is present. What we are doing here is silently creating columnwise data after the all-gather for the all-gathered tensor, even though the original sharded data tensor didn't have that data. In my opinion, I am against any change here, since even doing such a validation and throwing an error is going to incur CPU overheads when using …

**Collaborator:** The same comment applies to every other FSDP2-related change.
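The condition the PR switches to, `is_backward_pass or torch.is_grad_enabled()`, can be demonstrated on its own. This is a sketch of the decision logic only, under the assumption stated in the diff, not of the FSDP2 hook itself:

```python
import torch


def columnwise_needed(is_backward_pass: bool) -> bool:
    # Mirrors the new condition in fsdp_pre_all_gather when
    # reshard_after_forward=False: the backward pass always needs
    # columnwise data, and a grad-enabled forward implies a backward
    # pass will follow.
    return is_backward_pass or torch.is_grad_enabled()


assert columnwise_needed(True) is True        # backward pass
assert columnwise_needed(False) is True       # grad-enabled forward (training)
with torch.no_grad():
    assert columnwise_needed(False) is False  # inference: rowwise data suffices
```

This also makes the behavior change the reviewers debate concrete: an evaluation forward pass run without `torch.no_grad()` now requests columnwise data, whereas the old code consulted `self._quantizer.columnwise_usage` instead.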
```diff
         # For 2D block scaling (128x128 blocks), columnwise data and scales are
         # the transpose of rowwise data and scales. Only all-gather the rowwise
```
```diff
@@ -667,9 +667,9 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
         # If not resharded after forward pass, the same weights allgathered in forward
         # are used again in backward. And hence if we need the columnwise data/scale_inv,
         # we need to send them as well for allgather in forward pass itself.
+        training_state = param_group._training_state
+        is_backward_pass = training_state == TrainingState.PRE_BACKWARD
         if reshard_after_forward:
-            training_state = param_group._training_state
-            is_backward_pass = training_state == TrainingState.PRE_BACKWARD
             # Allgather only the necessary tensors based on forward/backward pass
             rowwise_usage = not is_backward_pass
             columnwise_usage = is_backward_pass
@@ -681,9 +681,16 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
         else:
             # rowwise usage is always needed for forward pass.
             rowwise_usage = True
-            columnwise_usage = self._quantizer.columnwise_usage
+            columnwise_usage = is_backward_pass or torch.is_grad_enabled()
         sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
         if columnwise_usage:
+            if self._columnwise_data is None or columnwise_scale_inv is None:
+                raise RuntimeError(
+                    "FSDP2 (reshard_after_forward=False) needs columnwise MXFP8 data "
+                    "for the upcoming backward pass, but the local shard has none. "
+                    "Ensure the weight is quantized with columnwise_usage=True before "
+                    "this all-gather."
+                )
```
Comment on lines +684 to +693

**Contributor:** Users who ran eval with grads enabled previously got silently incorrect (stale) data; they now get a hard crash. While the crash is more correct, the error message could guide them better.
```diff
             # If weights are not resharded after forward, then both
             # rowwise and columnwise data/scale_inv need to be allgathered.
             sharded_tensors += (self._columnwise_data, columnwise_scale_inv)
```
The test doesn't make sense to me. I don't think we should be toggling the quantizer usages in the quantized_model_init case at all. This breaks the very principle that a quantized tensor and its internal quantizer shouldn't be in conflict with each other. And here, columnwise_data is present for the quantized tensor even though columnwise_usage is set to False.
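The conflict described in this objection can be stated as a tiny predicate. The names are illustrative only, not Transformer Engine identifiers: data presence and the quantizer's usage flag must match.

```python
def usage_consistent(has_columnwise_data: bool, columnwise_usage: bool) -> bool:
    # A quantized tensor and its internal quantizer agree when columnwise
    # data is present exactly when the quantizer claims columnwise usage.
    return has_columnwise_data == columnwise_usage


assert usage_consistent(True, True)
assert usage_consistent(False, False)
# The state the reviewer objects to: columnwise data exists even though
# the quantizer has columnwise_usage=False.
assert not usage_consistent(True, False)
```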