Skip to content

Commit bfcc98d

Browse files
0.21.6
1 parent a571d86 commit bfcc98d

2 files changed

Lines changed: 39 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.21.5"
10+
version = "0.21.6"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/utils/seed.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import numpy as np
2+
import random
3+
import torch
4+
import os
5+
6+
7+
def set_all_seeds(seed: int):
8+
"""Set the seed for all relevant random number generators to ensure reproducibility.
9+
This function sets the seed for Python's built-in `random` module, NumPy,
10+
and PyTorch's CPU and GPU (CUDA) random number generators. It also configures
11+
PyTorch's settings to improve the reproducibility of experiments, which is
12+
crucial when debugging or comparing model performances.
13+
14+
Args:
15+
seed (int): The seed value to be set for all random number generators.
16+
17+
Example:
18+
>>> from spotpython.utils.seed import set_all_seeds
19+
>>> set_all_seeds(42)
20+
>>> # Proceed with model initialization or data processing to ensure results can be reproduced
21+
>>> model = SomeModel() # Replace with actual model
22+
>>> train_model(model) # Replace with actual training function
23+
24+
Notes:
25+
- Setting `torch.backends.cudnn.deterministic` to `True` can make computations
26+
more reproducible but at the potential cost of performance.
27+
- Additional considerations may be necessary for complete reproducibility
28+
in distributed or multi-threaded setups.
29+
"""
30+
random.seed(seed)
31+
np.random.seed(seed)
32+
torch.manual_seed(seed)
33+
if torch.cuda.is_available():
34+
torch.cuda.manual_seed_all(seed)
35+
torch.backends.cudnn.deterministic = True # Improvements for reproducibility
36+
torch.backends.cudnn.benchmark = False
37+
38+
os.environ["PYTHONHASHSEED"] = str(seed) # Ensuring hash-based functions use a consistent seed

0 commit comments

Comments
 (0)