# [JAX] Fix MNIST L2 jax test instability #2933
Changes from all commits: `121e6aa`, `7208790`, `eb9a3cd`, `10ca394`, `a14db00`
```diff
@@ -3,6 +3,7 @@
 # See LICENSE for license information.
 """MNIST training on single GPU"""
 import argparse
+import math
 import unittest
 from functools import partial
 import sys
@@ -223,6 +224,10 @@ def train_and_evaluate(args):
         print("PASSED")
         return None
 
+    train_losses = []
+    train_accuracies = []
+    test_losses = []
+    test_accuracies = []
     for epoch in range(1, args.epochs + 1):
         rng, input_rng = jax.random.split(rng)
         rng, dropout_rng = jax.random.split(rng)
@@ -233,6 +238,11 @@ def train_and_evaluate(args):
         )
         test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
 
+        train_losses.append(train_loss)
+        train_accuracies.append(train_accuracy)
+        test_losses.append(test_loss)
+        test_accuracies.append(test_accuracy)
+
         print(
             f"Epoch: {epoch:>2} "
             f"Train Loss: {train_loss:.6f} "
@@ -241,7 +251,7 @@ def train_and_evaluate(args):
             f"Test Accuracy: {test_accuracy:.6f} "
         )
 
-    return [train_loss, train_accuracy, test_loss, test_accuracy]
+    return [train_losses, train_accuracies, test_losses, test_accuracies]
 
 
 def mnist_parser(args):
@@ -324,15 +334,42 @@ def setUpClass(cls):
 
     @staticmethod
     def verify(actual):
-        """Check If loss and accuracy match target"""
-        desired_traing_loss = 0.055
+        """Check that loss and accuracy match target.
+
+        ``actual`` is ``[train_losses, train_accuracies, test_losses, test_accuracies]``,
+        i.e. per-epoch lists of metrics. To avoid flakiness from stochastic noise in
+        the final epoch near convergence (especially under FP8), the check considers
+        a tail window of the last ~10% of epochs (at least 2) and asserts on the
+        best metric within that window.
+        """
+        train_losses, train_accuracies, test_losses, test_accuracies = actual
+        epochs = len(train_losses)
+        tail = max(2, math.ceil(epochs * 0.1))
+        tail = min(tail, epochs)
```
**Contributor** commented on lines +347 to +348:

> With …
```diff
+
+        best_train_loss = min(train_losses[-tail:])
+        best_train_accuracy = max(train_accuracies[-tail:])
+        best_test_loss = min(test_losses[-tail:])
+        best_test_accuracy = max(test_accuracies[-tail:])
+
+        desired_traing_loss = 0.06
         desired_traing_accuracy = 0.98
-        desired_test_loss = 0.045
-        desired_test_accuracy = 0.098
-        assert actual[0] < desired_traing_loss
-        assert actual[1] > desired_traing_accuracy
-        assert actual[2] < desired_test_loss
-        assert actual[3] > desired_test_accuracy
+        desired_test_loss = 0.05
+        desired_test_accuracy = 0.98
+        assert (
+            best_train_loss < desired_traing_loss
+        ), f"best train loss over last {tail} epochs {best_train_loss} >= {desired_traing_loss}"
+        assert best_train_accuracy > desired_traing_accuracy, (
+            f"best train accuracy over last {tail} epochs {best_train_accuracy} "
+            f"<= {desired_traing_accuracy}"
+        )
+        assert (
+            best_test_loss < desired_test_loss
+        ), f"best test loss over last {tail} epochs {best_test_loss} >= {desired_test_loss}"
+        assert best_test_accuracy > desired_test_accuracy, (
+            f"best test accuracy over last {tail} epochs {best_test_accuracy} "
+            f"<= {desired_test_accuracy}"
+        )
 
     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
```
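For a sense of how the new tail window behaves, here is a standalone sketch that mirrors the arithmetic in `verify` above (illustrative only, not code from the PR):

```python
import math

def tail_window(epochs: int) -> int:
    # Last ~10% of epochs, at least 2, clamped so the window never
    # exceeds the number of epochs actually run.
    tail = max(2, math.ceil(epochs * 0.1))
    return min(tail, epochs)

print(tail_window(1))    # 1  -- clamp kicks in for a single-epoch run
print(tail_window(5))    # 2  -- minimum window of two epochs
print(tail_window(100))  # 10 -- last 10% of a long run
```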
> `verify()` will raise an unhelpful `TypeError` if passed `None`: `train_and_evaluate` returns `None` when `args.dry_run` is `True`. If a future test (or a direct call) passes `None` to `verify()`, the tuple unpack on line 345 raises `TypeError: cannot unpack non-iterable NoneType object` with no context. Adding an early guard keeps failure messages readable:
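A minimal sketch of such a guard (the helper name `_require_metrics` is hypothetical, an illustration rather than the reviewer's exact code):

```python
def _require_metrics(actual):
    """Fail fast with a readable message when verify() gets no metrics.

    Hypothetical helper illustrating the suggested early guard; the name
    and message are assumptions, not part of this PR.
    """
    assert actual is not None, (
        "verify() received None -- train_and_evaluate() returns None when "
        "args.dry_run is True, so there are no metrics to check."
    )
    return actual


# Usage at the top of verify(actual):
# train_losses, train_accuracies, test_losses, test_accuracies = _require_metrics(actual)
```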