Implement train test split #37

Open
SarahAlidoost wants to merge 13 commits into main from train_test

Conversation

@SarahAlidoost
Member

@SarahAlidoost SarahAlidoost commented Apr 13, 2026

closes #28

This PR:

  • implements a train/test/validation split
  • adds a stride option to the dataset to increase the number of samples (see explanation 1 below)
  • handles over-fitting by switching from torch.optim.Adam to torch.optim.AdamW, exposing dropout in the model, and using a validation set during training (see explanation 2 below)
  • fixes the mean/std calculation to use the target instead of the input
  • runs the notebook for residuals

explanations:

  1. Our model has a lot of parameters (see the default arguments of the model), so sampling the whole globe just once doesn't give us enough training data. This can lead to high loss on the test and validation sets. One approach is to use overlapping tiles to create more samples, as done in section "2.3 Data augmentation and pre-processing" of the MAESSTRO paper. That's why I added a stride option to the dataset. I also decreased the number of parameters, especially embed_dim, in the model in the example notebook. This is something to fix later when building a proper training workflow on larger data on HPC.

  2. Another issue was over-fitting. I used a validation set during training, similar to what they did in the MAESSTRO code, but they actually used different years for training and validation (e.g. 2012 vs 2011). Since I work with a small dataset and use a stride to create more samples, there is some overlap between the train, test, and validation sets. This is something to fix later when building a proper training workflow on larger data on HPC.
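The stride idea in explanation 1 can be illustrated with a minimal NumPy sketch (this is not the PR's actual dataset code; the function names are hypothetical). A stride smaller than the tile size produces overlapping tiles and therefore many more samples from the same field:

```python
import numpy as np

def tile_origins(height, width, tile, stride):
    """Enumerate top-left corners of (possibly overlapping) tiles."""
    rows = range(0, height - tile + 1, stride)
    cols = range(0, width - tile + 1, stride)
    return [(r, c) for r in rows for c in cols]

def extract_tiles(field, tile, stride):
    """Cut a 2-D field into tiles of shape (tile, tile)."""
    h, w = field.shape
    return np.stack([field[r:r + tile, c:c + tile]
                     for r, c in tile_origins(h, w, tile, stride)])

field = np.arange(64 * 64, dtype=float).reshape(64, 64)
print(len(extract_tiles(field, tile=16, stride=16)))  # non-overlapping: 16 samples
print(len(extract_tiles(field, tile=16, stride=8)))   # overlapping: 49 samples
```

Halving the stride here roughly triples the sample count, which is the data-augmentation effect the PR relies on.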
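Using a validation set during training, as in explanation 2, typically means tracking the validation loss each epoch and stopping once it stops improving. A generic early-stopping sketch in plain Python (this is not the PR's actual implementation; the `EarlyStopper` name and its parameters are illustrative):

```python
class EarlyStopper:
    """Signal a stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        # An improvement must beat the best loss by at least min_delta.
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopper(patience=2)
for epoch, val_loss in enumerate([1.0, 0.8, 0.9, 0.95, 0.7]):
    if stopper.step(val_loss):
        print(f"stop at epoch {epoch}")  # → stop at epoch 3
        break
```

In a real loop the check would run after each validation pass, and the model weights from the best epoch would be the ones saved.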

@SarahAlidoost SarahAlidoost marked this pull request as ready for review April 20, 2026 13:06
Collaborator

@rogerkuou rogerkuou left a comment

Nice implementation @SarahAlidoost !

I only have two minor comments; see if they are useful. Feel free to merge!

Comment thread climanet/train.py Outdated
Comment thread climanet/utils.py
        model_path,
    )
    if verbose:
        print(f"Model saved to {model_path}")
Collaborator

Something I just found is that the print function does not flush its output live to the slurm log file.

Shall we replace the print functions with logging?

Collaborator

To add to this: executing the Python script with -u does help, but the logging option still seems to be the more structural solution, since it gives more info.

Member Author

Actually, it is good that the print statement is not in the slurm log file. The print statement is mainly for the example notebook. On HPC, the verbose variable should be False; instead, we implemented proper logging using torch.utils.tensorboard in #34.

Co-authored-by: Ou Ku <o.ku@esciencecenter.nl>
Collaborator

@meiertgrootes meiertgrootes left a comment

Indeed, as Ou said. A very nice implementation. I have no further comments on the code at this point.
On the topic of data augmentation by using overlapping patches, I agree it seems to be a good means of getting more training data. However, this comes with a price, in particular when using a masking based strategy. This is a significant issue for MAE, but even with our physical based masks this is relevant. By creating overlapping patches including for regions where we may have data for one year, but not for another, we increase the probability that the models focuses on/learns local interpolations more at the expense of general representations. This needs to be balanced, so augmentation is fine, but should be used sparingly.

@SarahAlidoost
Member Author

@rogerkuou and @meiertgrootes Thanks for the reviews. I will wait for #39 to be merged first, because there are some conflicts; then I will fix the conflicts in this PR.

We need to implement a train/test/validation strategy and hyper-parameter tuning. I made issue #40, please share your ideas.

Development

Successfully merging this pull request may close these issues.

Introduce train validation test split

3 participants