-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
152 lines (124 loc) · 4.57 KB
/
train.py
File metadata and controls
152 lines (124 loc) · 4.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""
from typing import Optional
import pyrootutils
# Resolve the project root before any `prima` package import below.
# NOTE(review): per pyrootutils docs this presumably searches upward from this file
# for a directory containing one of the `indicator` entries, adds that root to
# sys.path (pythonpath=True) and loads a `.env` file from it (dotenv=True) —
# confirm. That is why this call sits between the third-party imports and the
# `prima.*` imports; moving it below them would likely break those imports.
# `root` itself is not referenced again in the visible part of this file.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)
import os
import sys
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.callbacks import TQDMProgressBar
from tqdm import tqdm
from prima.datasets import DataModule
from prima.models.prima import PRIMA
from prima.utils.pylogger import get_pylogger
from prima.utils.misc import log_hyperparameters
import signal
# Reset SIGUSR1 to its default disposition (SIG_DFL) at import time, replacing any
# handler installed earlier in this process.
# NOTE(review): presumably intended to keep SIGUSR1 from being intercepted
# (e.g. by a cluster/launcher integration) — confirm intent. SIGUSR2, not SIGUSR1,
# is the signal passed to SLURMEnvironment as the requeue signal in main().
# signal.SIGUSR1 exists only on Unix, so this line raises on Windows.
signal.signal(signal.SIGUSR1, signal.SIG_DFL)
class MyTQDMProgressBar(TQDMProgressBar):
    """Lightning TQDM progress bar with a fixed character width.

    Both the training and the validation bar are rendered ``ncols`` characters
    wide with terminal-driven resizing turned off; the validation bar is
    written to ``sys.stdout``.

    Args:
        ncols: Bar width in characters. Defaults to 150.
        **kwargs: Passed through unchanged to ``TQDMProgressBar.__init__``
            (e.g. ``refresh_rate``).
    """

    def __init__(self, *, ncols: int = 150, **kwargs):
        super().__init__(**kwargs)
        self._ncols = ncols

    def init_train_tqdm(self):
        """Build the training bar with Lightning's defaults, then fix its width."""
        bar = super().init_train_tqdm()
        bar.ncols = self._ncols
        # Turn off terminal-driven resizing so the fixed width above sticks.
        bar.dynamic_ncols = False
        return bar

    def init_validation_tqdm(self):
        """Build a fixed-width validation bar on ``sys.stdout``.

        The bar is created directly with ``tqdm`` at ``position=0`` with
        ``leave=True``, so it stays on screen after each validation pass.
        """
        return tqdm(
            desc=self.validation_description,
            position=0,
            disable=self.is_disabled,
            leave=True,
            file=sys.stdout,
            dynamic_ncols=False,
            ncols=self._ncols,
        )
def _find_resume_checkpoint(ckpt_dir: str) -> Optional[str]:
    """Return the most recently modified ``last*.ckpt`` in *ckpt_dir*, or None.

    Considers ``last.ckpt`` and versioned names such as ``last-v1.ckpt`` or
    ``last-v2.ckpt``. It picks the newest by modification time, so a stale copy
    left next to a fresher one is not resumed from. On an mtime tie the
    lexicographically first path wins, which prefers ``last-v1.ckpt`` over
    ``last.ckpt``.
    """
    if not os.path.isdir(ckpt_dir):
        return None
    candidates = sorted(
        os.path.join(ckpt_dir, name)
        for name in os.listdir(ckpt_dir)
        if name == "last.ckpt" or (name.startswith("last-v") and name.endswith(".ckpt"))
    )
    candidates = [path for path in candidates if os.path.isfile(path)]
    if not candidates:
        return None
    # `candidates` is sorted, so max() resolves mtime ties deterministically.
    return max(candidates, key=os.path.getmtime)


@hydra.main(version_base="1.2", config_path="./configs_hydra", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Train PRIMA using the Hydra config *cfg*.

    Builds the data module, model, TensorBoard logger and callbacks. It then
    instantiates the trainer from ``cfg.trainer``, logs hyperparameters and
    fits. Training resumes from the newest ``last*.ckpt`` under
    ``<cfg.paths.output_dir>/checkpoints`` when one exists.

    Returns:
        Always ``None`` (there is no return statement). NOTE(review): the
        ``Optional[float]`` annotation presumably leaves room for returning a
        metric to a Hydra sweeper — confirm.
    """
    ckpt_dir = os.path.join(cfg.paths.output_dir, 'checkpoints')

    datamodule = DataModule(cfg)
    model = PRIMA(cfg)

    # TensorBoard logger; the empty name/version keep event files directly under
    # <output_dir>/tensorboard instead of name/version_N subfolders.
    logger = TensorBoardLogger(os.path.join(cfg.paths.output_dir, 'tensorboard'), name='', version='',
                               default_hp_metric=False)
    loggers = [logger]

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=ckpt_dir,
        every_n_epochs=cfg.GENERAL.CHECKPOINT_EPOCHS,
        save_last=True,
        # Monitor a metric so `save_top_k` keeps the best checkpoints instead of the
        # most recent ones. We monitor the validation loss logged as 'val/loss'
        # (lower is better).
        monitor='val/loss',
        mode='min',
        save_top_k=cfg.GENERAL.CHECKPOINT_SAVE_TOP_K,
        # Placeholders must use the logged metric key. Lightning fills unknown keys
        # with 0, so a `{val_loss}` placeholder would name every file
        # `...-val_loss=0.0000` unless a separate 'val_loss' metric is logged.
        # auto_insert_metric_name=False keeps the '/' inside the placeholder only,
        # so no nested directory is created. Files look like
        # `best-epoch=005-val_loss=0.1234.ckpt`.
        filename="best-epoch={epoch:03d}-val_loss={val/loss:.4f}",
        auto_insert_metric_name=False,
    )
    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
    callbacks = [
        checkpoint_callback,
        lr_monitor,
        MyTQDMProgressBar(),
    ]

    log = get_pylogger(__name__)
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # When a launcher is configured, hand Lightning SIGUSR2 as the requeue signal.
    plugins = (
        SLURMEnvironment(requeue_signal=signal.SIGUSR2)
        if cfg.get('launcher', None) is not None
        else None
    )
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=loggers,
        plugins=plugins,
        sync_batchnorm=True,
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }
    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    ckpt_path = _find_resume_checkpoint(ckpt_dir)
    if ckpt_path is not None:
        log.info(f"Resuming from checkpoint: {ckpt_path}")
    else:
        log.info("No checkpoint found, starting from scratch")

    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    log.info("Fitting done")
if __name__ == "__main__":
    import gc

    import torch

    # Collect garbage and release cached CUDA allocator blocks, then print this
    # process's per-device allocator usage before handing control to Hydra.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        mib = 1024 ** 2
        for device_idx in range(torch.cuda.device_count()):
            allocated_mib = torch.cuda.memory_allocated(device_idx) / mib
            reserved_mib = torch.cuda.memory_reserved(device_idx) / mib
            print(f"GPU {device_idx}: {allocated_mib:.2f} MiB allocated, "
                  f"{reserved_mib:.2f} MiB reserved")
    main()