diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py index fd8b3c84..3b0e4091 100644 --- a/diffsynth/diffusion/runner.py +++ b/diffsynth/diffusion/runner.py @@ -44,20 +44,24 @@ def launch_training_task( model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) initialize_deepspeed_gradient_checkpointing(accelerator) + progress_bar_total = (len(dataloader) + accelerator.gradient_accumulation_steps - 1) // accelerator.gradient_accumulation_steps for epoch_id in range(num_epochs): - for data in tqdm(dataloader): - with accelerator.accumulate(model): - if dataset.load_from_cache: - loss = model({}, inputs=data) - else: - loss = model(data) - accelerator.backward(loss) - if enable_model_cpu_offload: - offload_manager.after_backward() - optimizer.step() - scheduler.step() - optimizer.zero_grad() - model_logger.on_step_end(accelerator, model, save_steps, loss=loss) + with tqdm(total=progress_bar_total) as progress_bar: + for data in dataloader: + with accelerator.accumulate(model): + if dataset.load_from_cache: + loss = model({}, inputs=data) + else: + loss = model(data) + accelerator.backward(loss) + if enable_model_cpu_offload: + offload_manager.after_backward() + optimizer.step() + scheduler.step() + optimizer.zero_grad() + if accelerator.sync_gradients: + model_logger.on_step_end(accelerator, model, save_steps, loss=loss) + progress_bar.update(1) if save_steps is None: model_logger.on_epoch_end(accelerator, model, epoch_id)