"""
This module provides a collection of utility functions and classes primarily designed to support distributed training, logging, and model management within a PyTorch deep learning workflow.
It includes functionalities for:
- Initializing and managing distributed training environments (DDP).
- Controlling print statements for master processes in distributed setups.
- Setting random seeds for reproducibility.
- Gradient norm calculation.
- Saving and loading model checkpoints.
- Converting data formats between PyTorch tensors and NumPy arrays.
- Logging metrics to TensorBoard.
- Tracking and smoothing metric values during training.
- Iterating with progress logging.
- Updating nested dictionaries.
- Cleaning directory walks by excluding specific files/directories.
The module aims to streamline common deep learning operations, especially
in distributed and large-scale training scenarios.
"""
import os
import re
import sys
import builtins
import time
import glob
import datetime
import numpy as np
from collections import defaultdict, deque
import torch
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter
from pathlib import Path
from yacs.config import CfgNode as CN
from functools import partial
import collections.abc
import gc
from typing import (
Optional,
Tuple,
List,
Iterator,
)
from numpy.typing import NDArray
import multiprocessing
from torch import inf
from datetime import timedelta
original_print = builtins.print
[docs]
def setup_for_distributed(is_master):
"""
Disable printing for non-master processes in a distributed training setup.
This function replaces the built-in `print` function with a custom one
that only prints output if the current process is the master process (rank 0),
or if `force=True` is passed to the print call. This prevents cluttered
output when running on multiple GPUs/nodes.
Parameters
----------
is_master : bool
True if the current process is the master process (rank 0), False otherwise.
"""
builtin_print = original_print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
force = force or (get_world_size() > 8)
if is_master or force:
now = datetime.datetime.now().time()
builtin_print("[{}] ".format(now), end="") # print with time stamp
builtin_print(*args, **kwargs)
builtins.print = print
[docs]
def is_dist_avail_and_initialized():
"""
Check if PyTorch distributed backend is available and initialized.
Returns
-------
bool
True if distributed training is available and initialized, False otherwise.
"""
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
[docs]
def get_world_size():
"""
Return the total number of participating processes in the distributed group.
Returns 0 if distributed mode is not initialized.
Returns
-------
int
The world size.
"""
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
[docs]
def get_rank():
"""
Return the rank of the current process in the distributed group.
Returns 0 if distributed mode is not initialized. The master process typically has rank 0.
Returns
-------
int
The rank of the current process.
"""
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
[docs]
def is_main_process():
"""
Check if the current process is the main (master) process (rank 0).
Returns
-------
bool
True if the current process is the main process, False otherwise.
"""
return get_rank() == 0
[docs]
def init_devices(args, cfg):
"""
Initialize the PyTorch distributed environment and sets up the device for the current process.
This function handles different distributed setup scenarios (e.g., ITP, environment variables, SLURM).
It sets the appropriate GPU device, initializes the process group, and configures
the custom print function for distributed logging.
Parameters
----------
args : Any
An object containing command-line arguments or configuration,
expected to have attributes like `dist_on_itp`, `gpu`, `dist_backend`, `dist_url`.
cfg : YACS CN object
The configuration object, used to determine the default device if CUDA is not available.
Returns
-------
torch.device
The PyTorch device assigned to the current process.
Raises
------
AssertionError
If distributed training is attempted without GPUs when environment variables are set.
"""
args.distributed = False
if args.dist_on_itp:
args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
args.dist_url = "tcp://%s:%s" % (
os.environ["MASTER_ADDR"],
os.environ["MASTER_PORT"],
)
os.environ["LOCAL_RANK"] = str(args.gpu)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
# ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
assert torch.cuda.is_available(), "Distributed training without GPUs is not supported!"
env_dict = {
key: os.environ[key]
for key in (
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"LOCAL_RANK",
"WORLD_SIZE",
)
}
print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif (
"SLURM_PROCID" in os.environ
and args.gpu is not None
and len(np.unique(np.array(args.gpu.strip().split(",")))) > 1
):
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
else:
print("Not using distributed mode")
setup_for_distributed(is_master=True) # hack
if torch.cuda.is_available() and args.gpu is not None and args.gpu != "":
# We don't neeed 'cuda:gpu_number' because we have set CUDA_VISIBLE_DEVICES,
# so the gpu_number is only used to know if we want to use GPU or not
device = torch.device(f"cuda")
print(f"**** Using GPU: '{device}' for training")
else:
device = torch.device(cfg.SYSTEM.DEVICE) # e.g. "cpu"
print(f"**** Using '{device}' for training")
return device
args.distributed = True
torch.cuda.set_device(args.gpu)
print(
"| distributed init (rank {}): {}, gpu {}".format(args.rank, args.dist_url, args.gpu),
flush=True,
)
if cfg.TEST.BY_CHUNKS.ENABLE and cfg.TEST.BY_CHUNKS.WORKFLOW_PROCESS.ENABLE:
os.environ["NCCL_BLOCKING_WAIT"] = "0" # not to enforce timeout in nccl backend
timeout_ms = 36000000
else:
timeout_ms = 1800000
if not is_dist_avail_and_initialized():
dist.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
timeout=timedelta(seconds=timeout_ms),
)
else:
# If it was initialized means that something may have been running in the past so here
# we are trying to clean all the cache as much as possible
torch.cuda.empty_cache()
gc.collect()
dist.barrier()
setup_for_distributed(args.rank == 0)
if args.rank == 0:
device = torch.device("cuda" if torch.cuda.is_available() else cfg.SYSTEM.DEVICE)
else:
device = torch.device("cuda" if torch.cuda.is_available() else cfg.SYSTEM.DEVICE)
return device
[docs]
def set_seed(seed=42):
"""
Set the random seed for reproducibility across multiple Python modules and PyTorch.
The seed is adjusted by the distributed rank to ensure different random
states for each process in a distributed setup, which can be beneficial
for certain operations (e.g., data loading).
Parameters
----------
seed : int, optional
The base seed value. Defaults to 42.
"""
seed = seed + get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
[docs]
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
"""
Compute the total norm of gradients for a collection of parameters.
This function is typically used for gradient clipping.
Parameters
----------
parameters : Iterable[torch.Tensor] or torch.Tensor
An iterable of model parameters or a single parameter tensor.
norm_type : float, optional
The type of the norm (e.g., 2.0 for L2 norm, `inf` for max norm).
Defaults to 2.0.
Returns
-------
torch.Tensor
The total norm of the gradients. Returns a tensor with value 0.0 if no
parameters have gradients.
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
norm_type = float(norm_type)
if len(parameters) == 0:
return torch.tensor(0.0)
device = parameters[0].grad.device
if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
else:
total_norm = torch.norm(
torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
norm_type,
)
return total_norm
[docs]
def save_model(cfg, biapy_version, jobname, epoch, model_without_ddp, optimizer, model_build_kwargs=None, extension="pth"):
"""
Save the model checkpoint to the specified path.
This function saves the model's state dictionary, optimizer state, current epoch,
configuration, and BiaPy version. It ensures that saving is performed only by
the main process in a distributed setup.
Parameters
----------
cfg : YACS CN object
The configuration object.
biapy_version : str
The current version of BiaPy.
jobname : str
The name of the current job/experiment.
epoch : int
The current epoch number.
model_without_ddp : nn.Module
The model instance, typically the unwrapped model if using DistributedDataParallel.
optimizer : torch.optim.Optimizer
The optimizer's state.
model_build_kwargs : Optional[Dict], optional
Keyword arguments used to build the model, useful for re-instantiating
the model from the checkpoint. Defaults to None.
extension : str, optional
The file extension for the checkpoint file. Options are 'pth' (native PyTorch format)
or 'safetensors' (https://github.com/huggingface/safetensors). Defaults to "pth".
Returns
-------
Path
The path to the saved checkpoint file.
"""
output_dir = Path(cfg.PATHS.CHECKPOINT)
checkpoint_paths = [output_dir / "{}-checkpoint-{}.{}".format(jobname, str(epoch), extension)]
for checkpoint_path in checkpoint_paths:
to_save = {
"model_build_kwargs": model_build_kwargs,
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"epoch": epoch,
"cfg": cfg,
"biapy_version": biapy_version,
}
save_on_master(to_save, checkpoint_path)
if len(checkpoint_paths) > 0:
return checkpoint_paths[0]
[docs]
def save_on_master(model_dict, checkpoint_path):
"""
Save a PyTorch object only if the current process is the main (master) process.
This is a wrapper around `torch.save` to ensure that checkpoints are not
redundantly saved by all processes in a distributed training setup.
Parameters
----------
*args : Any
Positional arguments to pass to `torch.save`.
**kwargs : Any
Keyword arguments to pass to `torch.save`.
"""
if is_main_process():
if str(checkpoint_path).endswith(".pth"):
torch.save(model_dict, checkpoint_path)
elif str(checkpoint_path).endswith(".safetensors"):
from safetensors.torch import save_file
save_file(model_dict["model"], checkpoint_path)
else:
raise ValueError("Unsupported checkpoint extension: {}".format(checkpoint_path))
[docs]
def get_checkpoint_path(cfg, jobname):
"""
Determine the path to the checkpoint file to load.
It selects the checkpoint based on `cfg.PATHS.CHECKPOINT_FILE`,
`cfg.MODEL.LOAD_CHECKPOINT_EPOCH` ("last_on_train" or "best_on_val"),
and the `jobname`.
Parameters
----------
cfg : YACS CN object
The configuration object. Key parameters:
- `cfg.PATHS.CHECKPOINT`: Base directory for checkpoints.
- `cfg.PATHS.CHECKPOINT_FILE`: Explicit path to a checkpoint file (if set).
- `cfg.MODEL.LOAD_CHECKPOINT_EPOCH`: Strategy for selecting checkpoint
("last_on_train" or "best_on_val").
jobname : str
The name of the current job/experiment.
Returns
-------
str
The absolute path to the checkpoint file without the extension (without the .pth or .safetensors).
Raises
------
NotImplementedError
If `cfg.MODEL.LOAD_CHECKPOINT_EPOCH` is an unrecognized value.
"""
checkpoint_dir = Path(cfg.PATHS.CHECKPOINT)
# Select the checkpoint source file
if cfg.PATHS.CHECKPOINT_FILE != "":
resume = cfg.PATHS.CHECKPOINT_FILE
resume, _ = os.path.splitext(resume)
else:
if cfg.MODEL.LOAD_CHECKPOINT_EPOCH == "last_on_train":
all_checkpoints = glob.glob(os.path.join(checkpoint_dir, "{}-checkpoint-*".format(jobname)))
latest_ckpt = -1
for ckpt in all_checkpoints:
t = ckpt.split("-")[-1].split(".")[0]
if t.isdigit():
latest_ckpt = max(int(t), latest_ckpt)
if latest_ckpt >= 0:
resume = os.path.join(checkpoint_dir, "{}-checkpoint-{}".format(jobname, latest_ckpt))
elif cfg.MODEL.LOAD_CHECKPOINT_EPOCH == "best_on_val":
resume = os.path.join(checkpoint_dir, "{}-checkpoint-best".format(jobname))
else:
raise NotImplementedError
return resume
[docs]
def load_model_checkpoint(cfg, jobname, model_without_ddp, device, optimizer=None, just_extract_checkpoint_info=False, skip_unmatched_layers=False) -> Tuple[int | CN | None, str | None]:
"""
Load a model checkpoint from disk.
This function handles loading the model's state dictionary, optimizer state,
and epoch number from a checkpoint file. It can also be configured to
only extract configuration information or to skip layers with mismatched shapes.
Parameters
----------
cfg : YACS CN object
The configuration object. Key parameters:
- `cfg.PATHS.CHECKPOINT_FILE`: Explicit path to checkpoint.
- `cfg.MODEL.LOAD_CHECKPOINT_EPOCH`: Strategy for checkpoint selection.
- `cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT`: List of items to load from the checkpoint (if available). Options are:
- "weights": Load model weights.
- "model_arch": Load model architecture.
- "optimizer": Load optimizer state.
- "epoch": Load epoch number.
jobname : str
The name of the current job/experiment.
model_without_ddp : nn.Module
The model instance (unwrapped if DDP is used) to load weights into.
device : torch.device
The device to map the loaded checkpoint to.
optimizer : Optional[torch.optim.Optimizer], optional
The optimizer instance to load state into. If None, optimizer state is not loaded.
Defaults to None.
just_extract_checkpoint_info : bool, optional
If True, only the configuration (`cfg`) and BiaPy version from the checkpoint
are returned, without loading model or optimizer states. Defaults to False.
skip_unmatched_layers : bool, optional
If True, layers in the checkpoint that have different shapes than the
current model's layers will be skipped during loading. Defaults to False.
Returns
-------
Tuple[int | CN | None, str | None]
If `just_extract_checkpoint_info` is True: returns `(checkpoint_cfg, biapy_version)`.
Otherwise: returns `(start_epoch, resume_path)`.
`checkpoint_cfg` and `biapy_version` can be `None` if not found in the checkpoint.
Raises
------
FileNotFoundError
If the specified checkpoint file does not exist.
"""
start_epoch = 0
resume = get_checkpoint_path(cfg, jobname)
# Take the first existing file with supported extension
for ext in ['.pth', '.safetensors']:
if os.path.exists(resume + ext):
resume += ext
break
if not os.path.exists(resume):
raise FileNotFoundError(f"Checkpoint file {resume} not found (considering .pth and .safetensors extensions)")
else:
if just_extract_checkpoint_info:
print("Extracting model from checkpoint file {}".format(resume))
else:
print("Loading checkpoint from file {}".format(resume))
# Load checkpoint file
torch.serialization.add_safe_globals([CN])
torch.serialization.add_safe_globals([set])
torch.serialization.add_safe_globals([partial])
torch.serialization.add_safe_globals([torch.nn.modules.normalization.LayerNorm])
if resume.startswith("https"):
checkpoint = torch.hub.load_state_dict_from_url(resume, map_location=device, check_hash=True)
elif resume.endswith(".safetensors"):
from safetensors.torch import load_file
checkpoint = {
"model": load_file(resume, device="cpu")
}
else: # ends with .pth
checkpoint = torch.load(resume, map_location=device, weights_only=True)
if just_extract_checkpoint_info:
if "cfg" not in checkpoint and not resume.endswith(".safetensors"):
print(
"Checkpoint seems to not be from BiaPy (v3.5.1 or later) as model building args couldn't be extracted. Thus, "
"the model will be built based on the current configuration"
)
return (
CN(checkpoint["cfg"]) if "cfg" in checkpoint else None,
str(checkpoint["biapy_version"]) if "biapy_version" in checkpoint else None,
)
if 'model' in checkpoint:
checkpoint_state_dict = checkpoint['model']
elif 'model_state_dict' in checkpoint:
checkpoint_state_dict = checkpoint['model_state_dict']
elif 'state_dict' in checkpoint:
# Common convention in PyTorch Lightning
checkpoint_state_dict = checkpoint['state_dict']
else:
checkpoint_state_dict = checkpoint
if not skip_unmatched_layers:
model_without_ddp.load_state_dict(checkpoint_state_dict, strict=False)
else:
# Filter out layers with mismatched shapes
filtered_state_dict = {}
model_state_dict = model_without_ddp.state_dict()
for k, v in checkpoint_state_dict.items():
if k in model_state_dict:
if torch.is_tensor(v):
if v.shape == model_state_dict[k].shape:
filtered_state_dict[k] = v
else:
print(f"Skipping layer '{k}' due to shape mismatch: checkpoint {v.shape} vs model {model_state_dict[k].shape}")
else:
print(f"Skipping layer '{k}' because its value is not a tensor (type {type(v)})")
else:
print(f"Skipping unexpected layer '{k}' not found in model.")
# Load only matching parameters
model_without_ddp.load_state_dict(filtered_state_dict, strict=False)
print("Model weights loaded!")
# Load also opt, epoch and scaler info
if "optimizer" in checkpoint and optimizer is not None and "optimizer" in cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT:
optimizer.load_state_dict(checkpoint["optimizer"], strict=False)
print("Optimizer info loaded!")
start_epoch = 0
if "epoch" in checkpoint and "epoch" in cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT:
start_epoch = checkpoint["epoch"]
if isinstance(start_epoch, str):
start_epoch = 0
else:
start_epoch = int(start_epoch) # type: ignore
print("Epoch loaded!")
return start_epoch, resume
[docs]
def all_reduce_mean(x):
"""
Perform an all-reduce operation on a scalar or single-element tensor, then computes the mean across all processes in a distributed group.
If not in a distributed environment, returns the input value directly.
Parameters
----------
x : float or torch.Tensor
The scalar value or single-element tensor to be reduced.
Returns
-------
float
The mean of `x` across all processes.
"""
world_size = get_world_size()
if world_size > 1:
x_reduce = torch.tensor(x).cuda()
dist.all_reduce(x_reduce)
x_reduce /= world_size
return x_reduce.item()
else:
return x
[docs]
def time_text(t):
"""
Format a time duration (in seconds) into a human-readable string.
Formats as 'Xh', 'Xm', or 'Xs' depending on the duration.
Parameters
----------
t : float
Time duration in seconds.
Returns
-------
str
Formatted time string.
"""
if t >= 3600:
return "{:.1f}h".format(t / 3600)
elif t >= 60:
return "{:.1f}m".format(t / 60)
else:
return "{:.1f}s".format(t)
[docs]
class TensorboardLogger(object):
"""A simple wrapper for `tensorboardX.SummaryWriter` to log scalar metrics."""
def __init__(self, log_dir):
"""
Initialize the TensorboardLogger.
Parameters
----------
log_dir : str
The directory where TensorBoard log files will be saved.
"""
self.writer = SummaryWriter(logdir=log_dir)
self.step = 0
[docs]
def set_step(self, step=None):
"""
Set the current global step for logging.
If `step` is None, increments the internal step counter.
Parameters
----------
step : Optional[int], optional
The specific step number to set. If None, increments the current step.
Defaults to None.
"""
if step is not None:
self.step = step
else:
self.step += 1
[docs]
def update(self, head="scalar", step=None, **kwargs):
"""
Log scalar values to TensorBoard.
Parameters
----------
head : str, optional
The main category for the scalar (e.g., "train_loss", "val_metrics").
Defaults to "scalar".
step : Optional[int], optional
The specific global step to log this update at. If None, uses the
internal `self.step`. Defaults to None.
**kwargs : float | int | torch.Tensor
Keyword arguments where keys are metric names (e.g., "loss", "accuracy")
and values are the corresponding scalar values (can be PyTorch tensors
or Python floats/ints).
"""
for k, v in kwargs.items():
if v is None:
continue
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
[docs]
def flush(self):
"""Ensure all pending events have been written to disk."""
self.writer.flush()
[docs]
class SmoothedValue(object):
"""Track a series of values and provides access to smoothed values (median, average) over a sliding window or the global series average."""
def __init__(self, window_size=20, fmt=None):
"""
Initialize the SmoothedValue tracker.
Parameters
----------
window_size : int, optional
The size of the sliding window for calculating median and average.
Defaults to 20.
fmt : Optional[str], optional
A format string for displaying the value. Placeholders include
`{median}`, `{avg}`, `{global_avg}`, `{max}`, `{value}`.
Defaults to "{median:.4f} ({global_avg:.4f})".
"""
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
self.eps = sys.float_info.epsilon
[docs]
def update(self, value, n=1):
"""
Update the tracker with a new value.
Parameters
----------
value : float
The new value to add.
n : int, optional
The number of samples represented by this value (e.g., batch size).
Defaults to 1.
"""
self.deque.append(value)
self.count += n
self.total += value * n
[docs]
def synchronize_between_processes(self):
"""
Synchronize the `count` and `total` attributes across all processes in a distributed environment using `dist.all_reduce`.
Warning: This method does *not* synchronize the `deque` (sliding window).
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
"""Return the median of the values in the current sliding window."""
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
"""Return the average of the values in the current sliding window."""
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self) -> float:
"""Return the global average of all values recorded since initialization."""
return self.total / (self.count + self.eps)
@property
def max(self):
"""Return the maximum value in the current sliding window."""
return max(self.deque)
@property
def value(self):
"""Return the most recently updated value."""
return self.deque[-1]
def __str__(self):
"""Return a formatted string representation of the smoothed value."""
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value,
)
[docs]
class MetricLogger(object):
"""Aggregate and logs various metrics using `SmoothedValue` objects."""
def __init__(self, delimiter="\t", verbose=False):
r"""
Initialize the MetricLogger.
Parameters
----------
delimiter : str, optional
The string used to separate metrics when printing. Defaults to \t".
verbose : bool, optional
If True, additional information (e.g., max GPU memory) is printed.
Defaults to False.
"""
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
self.verbose = verbose
[docs]
def update(self, **kwargs):
"""
Update the values of tracked metrics.
Parameters
----------
**kwargs : float | int | torch.Tensor
Keyword arguments where keys are metric names and values are their
current scalar values.
"""
for k, v in kwargs.items():
if v is None:
continue
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
"""Allow direct access to `SmoothedValue` objects via attribute lookup."""
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
def __str__(self):
"""Return a string representation of all tracked metrics."""
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
[docs]
def synchronize_between_processes(self):
"""Synchronize all tracked `SmoothedValue` meters across distributed processes."""
for meter in self.meters.values():
meter.synchronize_between_processes()
[docs]
def add_meter(self, name, meter):
"""
Add a custom `SmoothedValue` meter to the logger.
Parameters
----------
name : str
The name of the meter.
meter : SmoothedValue
The `SmoothedValue` instance to add.
"""
self.meters[name] = meter
[docs]
def log_every(self, iterable, print_freq, header=None):
"""
Log progress for an iterable, printing metrics at a specified frequency.
Parameters
----------
iterable : Iterable[Any]
The iterable (e.g., DataLoader) to iterate over.
print_freq : int
The frequency (in iterations) at which to print log messages.
header : Optional[str], optional
An optional header string to prepend to log messages. Defaults to None.
Yields
------
Any
Items from the input `iterable`.
"""
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
log_msg = [
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"iter-time: {time}",
]
if torch.cuda.is_available() and self.verbose:
log_msg.append("max mem: {memory:.0f}MB")
log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0
for obj in iterable:
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available() and self.verbose:
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
[docs]
def update_dict_with_existing_keys(d, u, not_recognized_keys=[], not_recognized_key_vals=[]):
"""
Recursively update a dictionary `d` with values from dictionary `u`, only for keys that already exist in `d`.
This function is useful for updating configuration dictionaries while
ensuring that no new keys are introduced from the update dictionary.
It also tracks keys from `u` that were not found in `d`.
Parameters
----------
d : Dict
The dictionary to be updated (destination).
u : Dict
The dictionary containing update values (source).
not_recognized_keys : Optional[List], optional
A list to append keys from `u` that were not found in `d`.
If None, a new list is created. Defaults to None.
not_recognized_key_vals : Optional[List], optional
A list to append values corresponding to `not_recognized_keys`.
If None, a new list is created. Defaults to None.
Returns
-------
Tuple[Dict, List, List]
- `d`: The updated dictionary.
- `not_recognized_keys`: List of keys from `u` not found in `d`.
- `not_recognized_key_vals`: List of values from `u` corresponding to `not_recognized_keys`.
"""
for k, v in u.items():
if k in d:
if isinstance(v, collections.abc.Mapping):
d[k], _, _ = update_dict_with_existing_keys(
d.get(k, {}), v, not_recognized_keys, not_recognized_key_vals
)
else:
if k in d:
d[k] = v
else:
not_recognized_keys.append(k)
not_recognized_key_vals.append(v)
else:
not_recognized_keys.append(k)
not_recognized_key_vals.append(v)
return d, not_recognized_keys, not_recognized_key_vals
[docs]
def os_walk_clean(
path: str,
exclude_files: Tuple = ("Thumbs.db", "desktop.ini", ".DS_Store"),
exclude_dirs: Tuple = (".git", "__pycache__")
) -> Iterator[Tuple[str, List[str], List[str]]]:
"""
Clean os.walk + robust natural sorting (numeric-aware).
Parameters
----------
path : str
The root directory to walk.
exclude_files : tuple, optional
Filenames to exclude from the results. Defaults to common system files.
exclude_dirs : tuple, optional
Directory names to exclude from the results. Defaults to common system directories.
Yields
------
Iterator[Tuple[str, List[str], List[str]]]
Yields tuples of (root, dirs, files) with excluded items removed and
directories/files sorted in natural order.
"""
def natural_key(s):
# Split filename into chunks of digits and non-digits,
# keeping all chunks as strings but zero-pad digits for proper order.
parts = re.findall(r'\d+|\D+', s)
# Pad numeric chunks so '2' < '10' < '100'
return [p.zfill(10) if p.isdigit() else p.lower() for p in parts]
for root, dirs, files in os.walk(path):
dirs[:] = [d for d in dirs if d not in exclude_dirs and not d.startswith('.')]
files = [f for f in files if f not in exclude_files and not f.startswith('.')]
# Safe natural sort
dirs.sort(key=natural_key)
files.sort(key=natural_key)
yield root, dirs, files
[docs]
def resolve_cpu_budget(user_num_cpus: int) -> int:
"""Total CPU cores budget for the entire job."""
if user_num_cpus == -1:
# If you use CPU affinity / SLURM cpuset, you may want to respect that instead of cpu_count()
return multiprocessing.cpu_count()
return max(1, int(user_num_cpus))
[docs]
def compute_threads_and_workers(
user_num_cpus: int,
world_size: int,
training_samples: Optional[int] = None,
max_workers_cap: int = 8
) -> Tuple[int, int, int, int]:
"""
Compute CPU budget, CPU per rank, main threads, and DataLoader workers per rank.
Parameters
----------
user_num_cpus : int
User-specified number of CPUs (-1 to use all available).
world_size : int
Number of distributed ranks/processes.
training_samples : int, optional
Number of training samples (to limit workers for small datasets).
max_workers_cap : int, optional
Maximum cap on DataLoader workers per rank. Defaults to 8.
Returns
-------
Tuple[int, int, int, int]
- `cpu_budget`: Total CPU cores budget for the job.
- `cpu_per_rank`: CPU cores allocated per rank.
- `main_threads`: Number of main threads for training process.
- `num_workers`: Number of DataLoader workers per rank.
"""
cpu_budget = resolve_cpu_budget(user_num_cpus)
world_size = max(1, int(world_size))
cpu_per_rank = max(1, cpu_budget // world_size)
# Conservative: keep training-process threads modest so DataLoader can breathe
main_threads = min(4, cpu_per_rank)
# Leave 1 core for OS/overhead
workers_per_rank_budget = max(0, cpu_per_rank - main_threads - 1)
num_workers = min(workers_per_rank_budget, max_workers_cap)
# Also don't spawn more workers than you have samples (helps tiny datasets)
if training_samples is not None:
num_workers = min(num_workers, max(0, int(training_samples)))
return cpu_budget, cpu_per_rank, main_threads, num_workers