From df9b080dec1b3daa5f5b29394d0001035720f540 Mon Sep 17 00:00:00 2001 From: Shehtab Date: Fri, 3 Apr 2026 14:59:27 -0400 Subject: [PATCH 1/7] Updated graphcast implementation --- .../GraphCast/data_utils/graphcast_graph.py | 115 +++--- experiments/GraphCast/dataset.py | 6 - experiments/GraphCast/layers.py | 68 ++-- experiments/GraphCast/model.py | 335 ++++++++++-------- 4 files changed, 287 insertions(+), 237 deletions(-) diff --git a/experiments/GraphCast/data_utils/graphcast_graph.py b/experiments/GraphCast/data_utils/graphcast_graph.py index 3d592e7..1137559 100644 --- a/experiments/GraphCast/data_utils/graphcast_graph.py +++ b/experiments/GraphCast/data_utils/graphcast_graph.py @@ -57,29 +57,29 @@ class GraphCastTopology: @dataclass class DistributedGraphCastGraph: + # Distributed environment info rank: int world_size: int ranks_per_graph: int + + # Graph metadata mesh_level: int lat_lon_grid: Tensor + + # Mesh vertex features mesh_graph_node_features: Tensor mesh_graph_edge_features: Tensor - mesh_graph_node_rank_placement: Tensor - mesh_graph_edge_rank_placement: Tensor - mesh_graph_src_indices: Tensor - mesh_graph_dst_indices: Tensor - mesh_graph_src_rank_placement: Tensor - mesh_graph_dst_rank_placement: Tensor - grid_rank_placement: Tensor + + # Grid vertex features mesh2grid_graph_node_features: Tensor - mesh2grid_graph_edge_features: Tensor - mesh2grid_graph_edge_rank_placement: Tensor - mesh2grid_graph_src_indices: Tensor - mesh2grid_graph_dst_indices: Tensor grid2mesh_graph_node_features: Tensor + + # Mesh <--> Grid edge features + mesh2grid_graph_edge_features: Tensor grid2mesh_graph_edge_features: Tensor - grid2mesh_graph_src_indices: Tensor - grid2mesh_graph_dst_indices: Tensor + + # Distributed graph info + distributed_comm_patterns: GraphCastCommPatterns def build_graphcast_comm_patterns(graph: GraphCastTopology) -> GraphCastCommPatterns: @@ -110,7 +110,6 @@ def build_graphcast_comm_patterns(graph: GraphCastTopology) -> GraphCastCommPatt 
grid2mesh_cp = build_communication_pattern( global_edge_list=grid2mesh_edges, partitioning=mesh_part, - neighbor_partitioning=grid_part, rank=rank, world_size=world_size, ) @@ -129,7 +128,6 @@ def build_graphcast_comm_patterns(graph: GraphCastTopology) -> GraphCastCommPatt mesh_cp = build_communication_pattern( global_edge_list=mesh_edges, partitioning=mesh_part, - neighbor_partitioning=mesh_part, rank=rank, world_size=world_size, ) @@ -147,7 +145,6 @@ def build_graphcast_comm_patterns(graph: GraphCastTopology) -> GraphCastCommPatt mesh2grid_cp = build_communication_pattern( global_edge_list=mesh2grid_edges, partitioning=grid_part, - neighbor_partitioning=mesh_part, rank=rank, world_size=world_size, ) @@ -220,6 +217,47 @@ def get_mesh_graph_partition(mesh_level: int, world_size: int): mesh_vertex_rank_placement = torch.tensor(mesh_vertex_rank_placement) return mesh_vertex_rank_placement + @staticmethod + def get_grid_vertex_partition( + lat: int, + lon: int, + mesh_vertex_rank_placement: torch.Tensor, + grid2mesh_grid_src_indices: torch.Tensor, + grid2mesh_mesh_dst_indices: torch.Tensor, + mesh2grid_mesh_src_indices: torch.Tensor, + world_size: int, + ) -> torch.Tensor: + """Generate the partitioning of grid vertices to minimize cross-rank edges. + + For each grid vertex, counts how many of its connected mesh vertices + (via both grid2mesh and mesh2grid edges) live on each rank, then assigns + the grid vertex to the rank with the plurality of connections. + + mesh2grid grid destinations are implicit: grid vertex i owns edges + [3i, 3i+1, 3i+2] since create_mesh2grid_graph assigns exactly 3 + edges (one face's vertices) per grid vertex. 
+ """ + num_grid = lat * lon + votes = torch.zeros(num_grid, world_size, dtype=torch.long) + + # --- grid2mesh contribution: grid vertex is src, mesh vertex is dst --- + g2m_ranks = mesh_vertex_rank_placement[grid2mesh_mesh_dst_indices.long()] + # Flatten (grid_vertex, rank) into a 1D index for scatter_add_ + g2m_flat_idx = grid2mesh_grid_src_indices.long() * world_size + g2m_ranks + votes.view(-1).scatter_add_(0, g2m_flat_idx, torch.ones_like(g2m_flat_idx)) + + # --- mesh2grid contribution: mesh vertex is src, grid vertex is dst --- + # Each grid vertex i has exactly 3 mesh2grid edges at positions [3i, 3i+1, 3i+2] + m2g_grid_dst = torch.arange(num_grid, dtype=torch.long).repeat_interleave(3) + m2g_ranks = mesh_vertex_rank_placement[mesh2grid_mesh_src_indices.long()] + m2g_flat_idx = m2g_grid_dst * world_size + m2g_ranks + votes.view(-1).scatter_add_(0, m2g_flat_idx, torch.ones_like(m2g_flat_idx)) + + # Assign each grid vertex to the rank with the most connections + grid_partitioning = votes.argmax(dim=1) + + return grid_partitioning + def get_mesh_graph(self, mesh_vertex_rank_placement: torch.Tensor): """Get the graph for the distributed graphcast graph.""" @@ -327,8 +365,6 @@ def get_grid2mesh_graph(self, mesh_graph_dict: dict): grid_vertex_rank_placement ) - # TODO: Consider we can have it to so grid2mesh edges don't require a - # backpropagation. 
(If encoder is after the gather / scatter) grid2mesh_graph_dict = { "node_features": torch.tensor([]), "edge_features": edge_features, @@ -340,7 +376,7 @@ def get_grid2mesh_graph(self, mesh_graph_dict: dict): } return grid2mesh_graph_dict - def get_mesh2grid_graph( + def get_mesh2grid_edges( self, grid_vertex_rank_placement, renumbered_vertices, @@ -359,7 +395,6 @@ def get_mesh2grid_graph( mesh2grid_edge_rank_placement = grid_vertex_rank_placement[dst_grid_indices] mesh2grid_graph_dict = { - "node_features": torch.tensor([]), "edge_features": edge_features, "src_indices": src_mesh_indices, "dst_indices": dst_grid_indices, @@ -401,10 +436,26 @@ def get_graphcast_graph( grid_vertex_rank_placement = grid2mesh_graph["grid_vertex_rank_placement"] renumbered_grid = grid2mesh_graph["renumbered_grid"] - mesh2grid_graph = self.get_mesh2grid_graph( + mesh2grid_graph = self.get_mesh2grid_edges( grid_vertex_rank_placement, renumbered_vertices, renumbered_grid ) + topology = GraphCastTopology( + rank=self.local_rank, + world_size=self.world_size, + ranks_per_graph=self.ranks_per_graph, + mesh_rank_placement=mesh_vertex_rank_placement, + grid_rank_placement=grid_vertex_rank_placement, + mesh_graph_src_indices=mesh_graph["src_indices"], + mesh_graph_dst_indices=mesh_graph["dst_indices"], + mesh2grid_graph_src_indices=mesh2grid_graph["src_indices"], + mesh2grid_graph_dst_indices=mesh2grid_graph["dst_indices"], + grid2mesh_graph_src_indices=grid2mesh_graph["src_indices"], + grid2mesh_graph_dst_indices=grid2mesh_graph["dst_indices"], + ) + + comm_patterns = build_graphcast_comm_patterns(topology) + return DistributedGraphCastGraph( rank=self.rank, world_size=self.world_size, @@ -413,25 +464,9 @@ def get_graphcast_graph( lat_lon_grid=self.lat_lon_grid, mesh_graph_node_features=mesh_graph["node_features"], mesh_graph_edge_features=mesh_graph["edge_features"], - mesh_graph_node_rank_placement=mesh_graph["node_rank_placement"], - 
mesh_graph_edge_rank_placement=mesh_graph["edge_rank_placement"], - mesh_graph_src_indices=mesh_graph["src_indices"], - mesh_graph_dst_indices=mesh_graph["dst_indices"], - mesh_graph_src_rank_placement=mesh_graph["src_rank_placement"], - mesh_graph_dst_rank_placement=mesh_graph["dst_rank_placement"], - grid_rank_placement=grid2mesh_graph["grid_vertex_rank_placement"], - mesh2grid_graph_node_features=mesh2grid_graph["node_features"], - mesh2grid_graph_edge_features=mesh2grid_graph["edge_features"], - mesh2grid_graph_edge_rank_placement=mesh2grid_graph[ - "mesh2grid_edge_rank_placement" - ], - mesh2grid_graph_src_indices=mesh2grid_graph["src_indices"], - mesh2grid_graph_dst_indices=mesh2grid_graph["dst_indices"], + mesh2grid_graph_node_features=torch.tensor([]), grid2mesh_graph_node_features=grid2mesh_graph["node_features"], + mesh2grid_graph_edge_features=mesh2grid_graph["edge_features"], grid2mesh_graph_edge_features=grid2mesh_graph["edge_features"], - grid2mesh_graph_edge_rank_placement=grid2mesh_graph[ - "grid2mesh_edge_rank_placement" - ], - grid2mesh_graph_src_indices=grid2mesh_graph["src_indices"], - grid2mesh_graph_dst_indices=grid2mesh_graph["dst_indices"], + distributed_comm_patterns=comm_patterns, ) diff --git a/experiments/GraphCast/dataset.py b/experiments/GraphCast/dataset.py index 3e4eb08..b15e700 100644 --- a/experiments/GraphCast/dataset.py +++ b/experiments/GraphCast/dataset.py @@ -261,20 +261,14 @@ def test_synthetic_weather_dataset(num_days, batch_size=1): print("Mesh label:\t", static_graph.mesh_level) print("Mesh Node features:\t", static_graph.mesh_graph_node_features.shape) print("Mesh Edge features:\t", static_graph.mesh_graph_edge_features.shape) - print("Mesh src indices:\t", static_graph.mesh_graph_src_indices.shape) - print("Mesh dst indices:\t", static_graph.mesh_graph_dst_indices.shape) print("=" * 80) print( "mesh2grid edge features:\t", static_graph.mesh2grid_graph_edge_features.shape ) - print("mesh2grid src indices:\t", 
static_graph.mesh2grid_graph_src_indices.shape) - print("mesh2grid dst indices:\t", static_graph.mesh2grid_graph_dst_indices.shape) print("=" * 80) print( "grid2mesh edge features:\t", static_graph.grid2mesh_graph_edge_features.shape ) - print("grid2mesh src indices:\t", static_graph.grid2mesh_graph_src_indices.shape) - print("grid2mesh dst indices:\t", static_graph.grid2mesh_graph_dst_indices.shape) print("=" * 80) diff --git a/experiments/GraphCast/layers.py b/experiments/GraphCast/layers.py index 524987b..027974b 100644 --- a/experiments/GraphCast/layers.py +++ b/experiments/GraphCast/layers.py @@ -11,14 +11,12 @@ # https://github.com/LBANN and https://github.com/LLNL/LBANN. # # SPDX-License-Identifier: (Apache-2.0) - -from typing import Tuple, Union -import numpy as np import torch import torch.nn as nn -from typing import Optional -from DGraph.Communicator import Communicator -from dist_utils import SingleProcessDummyCommunicator +from DGraph.utils.TimingReport import TimingReport + +""" +Local only layers for mesh processing. These layers do not perform any communication and can be used in both GraphCast and MeshGraphNet.""" class MeshGraphMLP(nn.Module): @@ -72,7 +70,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: The transformed tensor """ - return self._model(x) + with TimingReport("MeshGraphMLP/forward"): + return self._model(x) class MeshNodeBlock(nn.Module): @@ -83,7 +82,6 @@ def __init__( input_node_dim: int, input_edge_dim: int, output_node_dim: int, - comm: Union[Communicator, SingleProcessDummyCommunicator], hidden_dim: int = 512, num_hidden_layers: int = 1, aggregation_type: str = "sum", @@ -102,7 +100,6 @@ def __init__( super(MeshNodeBlock, self).__init__() assert aggregation_type in ["sum"], "Only sum aggregation is supported for now." 
self.aggregation_type = aggregation_type - self.comm = comm self.mesh_mlp = MeshGraphMLP( input_dim=input_node_dim + input_edge_dim, output_dim=output_node_dim, @@ -115,7 +112,6 @@ def forward( node_features: torch.Tensor, edge_features: torch.Tensor, src_indices: torch.Tensor, - rank_mapping: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Compute the node block @@ -129,17 +125,23 @@ def forward( Returns: The updated node features """ - # Sum all the edge features for each node num_local_nodes = node_features.shape[0] - # TODO: This can be optimized by a fused gather-scatter operation - S.Z - - aggregated_edge_features = self.comm.scatter( - edge_features, src_indices, rank_mapping, num_local_nodes - ) - # Concatenate the node and edge features - x = torch.cat([node_features, aggregated_edge_features], dim=-1) - # Apply the MLP - node_features_new = self.mesh_mlp(x) + node_features + with TimingReport("MeshNodeBlock/scatter_add"): + aggregated_edge_features = torch.zeros( + num_local_nodes, + edge_features.shape[-1], + device=edge_features.device, + dtype=edge_features.dtype, + ) + aggregated_edge_features.scatter_add_( + 0, + src_indices.unsqueeze(-1).expand(-1, edge_features.shape[-1]), + edge_features, + ) + + with TimingReport("MeshNodeBlock/mlp"): + x = torch.cat([node_features, aggregated_edge_features], dim=-1) + node_features_new = self.mesh_mlp(x) + node_features return node_features_new @@ -152,7 +154,6 @@ def __init__( input_dst_node_dim: int, input_edge_dim: int, output_edge_dim: int, - comm: Union[Communicator, SingleProcessDummyCommunicator], hidden_dim: int = 512, num_hidden_layers: int = 1, aggregation_type: str = "sum", @@ -162,7 +163,6 @@ def __init__( input_node_dim (int): The dimensionality of the input node features. input_edge_dim (int): The dimensionality of the input edge features. output_edge_dim (int): The dimensionality of the output edge features. - comm (CommunicatorBase): The communicator to use for distributed training. 
hidden_dim (int, optional): The dimensionality of the hidden layers. Defaults to 512. aggregation_type (str, optional): The type of aggregation to use. Defaults to "sum". """ @@ -171,7 +171,6 @@ def __init__( super(MeshEdgeBlock, self).__init__() assert aggregation_type in ["sum"], "Only sum aggregation is supported for now." self.aggregation_type = aggregation_type - self.comm = comm self.mesh_mlp = MeshGraphMLP( input_dim=input_src_node_dim + input_dst_node_dim + input_edge_dim, output_dim=output_edge_dim, @@ -186,8 +185,6 @@ def forward( edge_features: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, - src_rank_mapping: Optional[torch.Tensor] = None, - dst_rank_mapping: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Compute the edge block @@ -201,16 +198,13 @@ def forward( Returns: The updated edge features """ - # Concatenate the source and destination node features with the edge features - src_node_features = self.comm.gather( - src_node_features, src_indices, src_rank_mapping - ) - dst_node_features = self.comm.gather( - dst_node_features, dst_indices, dst_rank_mapping - ) - concatenated_features = torch.cat( - [src_node_features, dst_node_features, edge_features], dim=-1 - ) - # Apply the MLP - edge_features_new = self.mesh_mlp(concatenated_features) + edge_features + with TimingReport("MeshEdgeBlock/gather"): + src_node_features = src_node_features[src_indices] + dst_node_features = dst_node_features[dst_indices] + + with TimingReport("MeshEdgeBlock/mlp"): + concatenated_features = torch.cat( + [src_node_features, dst_node_features, edge_features], dim=-1 + ) + edge_features_new = self.mesh_mlp(concatenated_features) + edge_features return edge_features_new diff --git a/experiments/GraphCast/model.py b/experiments/GraphCast/model.py index da6459b..9183ad8 100644 --- a/experiments/GraphCast/model.py +++ b/experiments/GraphCast/model.py @@ -14,11 +14,14 @@ import torch import torch.nn as nn -from typing import Optional, Tuple +from 
typing import Tuple from torch import Tensor from layers import MeshEdgeBlock, MeshGraphMLP, MeshNodeBlock from graphcast_config import Config from data_utils.graphcast_graph import DistributedGraphCastGraph +from DGraph.distributed import HaloExchange +from DGraph.distributed.commInfo import CommunicationPattern +from DGraph.utils.TimingReport import TimingReport class GraphCastEmbedder(nn.Module): @@ -116,55 +119,62 @@ def __init__(self, cfg: Config, comm, *args, **kwargs) -> None: comm: Communicator object """ super().__init__(*args, **kwargs) + hidden_dim = cfg.model.hidden_dim - edge_block_invars = ( - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - comm, - cfg.model.hidden_dim, - ) - self.edge_mlp = MeshEdgeBlock(*edge_block_invars) - - node_block_invars = ( - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - comm, - cfg.model.hidden_dim, + self.exchanger = HaloExchange(comm) + + self.edge_mlp = MeshEdgeBlock( + input_src_node_dim=hidden_dim, + input_dst_node_dim=hidden_dim, + input_edge_dim=hidden_dim, + output_edge_dim=hidden_dim, + hidden_dim=hidden_dim, ) - self.mesh_node_mlp = MeshNodeBlock(*node_block_invars) - self.grid_node_mlp = MeshGraphMLP( - input_dim=cfg.model.hidden_dim, output_dim=cfg.model.hidden_dim + self.mesh_node_mlp = MeshNodeBlock( + input_node_dim=hidden_dim, + input_edge_dim=hidden_dim, + output_node_dim=hidden_dim, + hidden_dim=hidden_dim, ) + self.grid_node_mlp = MeshGraphMLP(input_dim=hidden_dim, output_dim=hidden_dim) def forward( self, grid_node_features: Tensor, mesh_node_features: Tensor, grid2mesh_edge_features: Tensor, - grid2mesh_edge_indices_src: Tensor, - grid2mesh_edge_indices_dst: Tensor, + comm_pattern: CommunicationPattern, ) -> Tuple[Tensor, Tensor]: + # local_edge_list: [E, 2] with [central=mesh, neighbor=grid/halo] + edge_index = comm_pattern.local_edge_list + dst_indices = edge_index[:, 0] # mesh (central, aggregation target) + src_indices = 
edge_index[:, 1] # grid/halo (neighbor, message source) + num_local = comm_pattern.num_local_vertices + + with TimingReport("encoder/halo_exchange"): + halo_features = self.exchanger(mesh_node_features, comm_pattern) + augmented = torch.cat([mesh_node_features, halo_features], dim=0) + + with TimingReport("encoder/edge_block"): + e_feats = self.edge_mlp( + src_node_features=augmented, + dst_node_features=augmented, + edge_features=grid2mesh_edge_features, + src_indices=src_indices, + dst_indices=dst_indices, + ) - e_feats = self.edge_mlp( - src_node_features=grid_node_features, - dst_node_features=mesh_node_features, - edge_features=grid2mesh_edge_features, - src_indices=grid2mesh_edge_indices_src, - dst_indices=grid2mesh_edge_indices_dst, - ) + with TimingReport("encoder/node_block"): + n_feats = self.mesh_node_mlp( + node_features=augmented[:num_local], + edge_features=e_feats, + src_indices=dst_indices, + ) - n_feats = self.mesh_node_mlp( - node_features=mesh_node_features, - edge_features=e_feats, - src_indices=grid2mesh_edge_indices_dst, - ) + with TimingReport("encoder/grid_mlp"): + grid_node_features = grid_node_features + self.grid_node_mlp(grid_node_features) mesh_node_features = mesh_node_features + n_feats - grid_node_features = grid_node_features + self.grid_node_mlp(grid_node_features) - return grid_node_features, mesh_node_features @@ -179,54 +189,73 @@ def __init__(self, cfg: Config, comm, *args, **kwargs): comm: Communicator object """ super().__init__() + hidden_dim = cfg.model.hidden_dim processor_layers = cfg.model.processor_layers - node_block_invars = ( - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - comm, - cfg.model.hidden_dim, + + self.exchanger = HaloExchange(comm) + + self.edge_processors = nn.ModuleList( + [ + MeshEdgeBlock( + input_src_node_dim=hidden_dim, + input_dst_node_dim=hidden_dim, + input_edge_dim=hidden_dim, + output_edge_dim=hidden_dim, + hidden_dim=hidden_dim, + ) + for _ in range(processor_layers) + 
] ) - edge_block_invars = ( - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - comm, - cfg.model.hidden_dim, + self.node_processors = nn.ModuleList( + [ + MeshNodeBlock( + input_node_dim=hidden_dim, + input_edge_dim=hidden_dim, + output_node_dim=hidden_dim, + hidden_dim=hidden_dim, + ) + for _ in range(processor_layers) + ] ) - edge_layers = [] - node_layers = [] - for _ in range(processor_layers): - edge_layers.append(MeshEdgeBlock(*edge_block_invars)) - for _ in range(processor_layers): - node_layers.append(MeshNodeBlock(*node_block_invars)) - - self.edge_processors = nn.ModuleList(edge_layers) - self.node_processors = nn.ModuleList(node_layers) def forward( self, embedded_mesh_features: Tensor, embedded_mesh2mesh_edge_features: Tensor, - mesh2mesh_edge_indices_src: Tensor, - mesh2mesh_edge_indices_dst: Tensor, + comm_pattern: CommunicationPattern, ) -> Tuple[Tensor, Tensor]: e_feats = embedded_mesh2mesh_edge_features n_feats = embedded_mesh_features - for edge_layer, node_layer in zip(self.edge_processors, self.node_processors): - e_feats = edge_layer( - n_feats, - n_feats, - e_feats, - mesh2mesh_edge_indices_src, - mesh2mesh_edge_indices_dst, - ) - n_feats = node_layer( - n_feats, - e_feats, - mesh2mesh_edge_indices_src, - ) + + # local_edge_list: [E, 2] with [central=mesh_dst, neighbor=mesh_src] + edge_index = comm_pattern.local_edge_list + dst_indices = edge_index[:, 0] # central (aggregation target) + src_indices = edge_index[:, 1] # neighbor (message source) + num_local = comm_pattern.num_local_vertices + + for i, (edge_layer, node_layer) in enumerate( + zip(self.edge_processors, self.node_processors) + ): + with TimingReport(f"processor/layer_{i}/halo_exchange"): + halo_features = self.exchanger(n_feats, comm_pattern) + augmented = torch.cat([n_feats, halo_features], dim=0) + + with TimingReport(f"processor/layer_{i}/edge_block"): + e_feats = edge_layer( + src_node_features=augmented, + 
dst_node_features=augmented, + edge_features=e_feats, + src_indices=src_indices, + dst_indices=dst_indices, + ) + + with TimingReport(f"processor/layer_{i}/node_block"): + n_feats = node_layer( + node_features=augmented[:num_local], + edge_features=e_feats, + src_indices=dst_indices, + ) + return n_feats, e_feats @@ -243,26 +272,22 @@ def __init__(self, cfg: Config, comm, *args, **kwargs): comm: Communicator object """ super().__init__() - edge_block_invars = ( - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - comm, - cfg.model.hidden_dim, + hidden_dim = cfg.model.hidden_dim + + self.exchanger = HaloExchange(comm) + + self.edge_mlp = MeshEdgeBlock( + input_src_node_dim=hidden_dim, + input_dst_node_dim=hidden_dim, + input_edge_dim=hidden_dim, + output_edge_dim=hidden_dim, + hidden_dim=hidden_dim, ) - self.comm = comm - self.edge_mlp = MeshEdgeBlock(*edge_block_invars) - dst_node_input_dim = cfg.model.hidden_dim - dst_node_output_dim = cfg.model.hidden_dim - m2g_edge_output_dim = cfg.model.hidden_dim self.node_mlp = MeshNodeBlock( - input_node_dim=dst_node_input_dim, - input_edge_dim=m2g_edge_output_dim, - output_node_dim=dst_node_output_dim, - hidden_dim=cfg.model.hidden_dim, - comm=comm, - num_hidden_layers=1, + input_node_dim=hidden_dim, + input_edge_dim=hidden_dim, + output_node_dim=hidden_dim, + hidden_dim=hidden_dim, ) def forward( @@ -270,42 +295,47 @@ def forward( mesh2grid_edge_features: Tensor, grid_node_features: Tensor, mesh_node_features: Tensor, - mesh2grid_edge_indices_src: Tensor, - mesh2grid_edge_indices_dst: Tensor, + comm_pattern: CommunicationPattern, ) -> Tensor: """ Args: mesh2grid_edge_features (Tensor): The edge features from the mesh to the grid grid_node_features (Tensor): The grid node features mesh_node_features (Tensor): The mesh node features - mesh2grid_edge_indices_src (Tensor): The source indices for the mesh2grid - bipartitate edges. 
These are the indices - of the mesh nodes that are connected to - the grid nodes. - mesh2grid_edge_indices_dst (Tensor): The destination indices for the mesh2grid - bipartitate edges. These are the indices of - the grid nodes that are connected to the - mesh nodes. + comm_pattern (CommunicationPattern): Precomputed communication pattern + for the mesh2grid bipartite graph (partitioned by grid vertex placement). Returns: (Tensor): The updated grid node features """ - e_feats = self.edge_mlp( - src_node_features=mesh_node_features, - dst_node_features=grid_node_features, - edge_features=mesh2grid_edge_features, - src_indices=mesh2grid_edge_indices_src, - dst_indices=mesh2grid_edge_indices_dst, - ) - n_feats = self.node_mlp( - node_features=grid_node_features, - edge_features=e_feats, - src_indices=mesh2grid_edge_indices_dst, - ) + # local_edge_list: [E, 2] with [central=grid, neighbor=mesh/halo] + edge_index = comm_pattern.local_edge_list + dst_indices = edge_index[:, 0] # grid (central, aggregation target) + src_indices = edge_index[:, 1] # mesh/halo (neighbor, message source) + num_local = comm_pattern.num_local_vertices + + with TimingReport("decoder/halo_exchange"): + # Mesh nodes are the neighbors (sources); grid nodes are the central (destination). 
+ halo_mesh_features = self.exchanger(mesh_node_features, comm_pattern) + augmented_mesh = torch.cat([mesh_node_features, halo_mesh_features], dim=0) + + with TimingReport("decoder/edge_block"): + e_feats = self.edge_mlp( + src_node_features=augmented_mesh, # mesh features (local + halo) + dst_node_features=grid_node_features, # grid features (destination side) + edge_features=mesh2grid_edge_features, + src_indices=src_indices, + dst_indices=dst_indices, + ) - n_feats = grid_node_features + n_feats + with TimingReport("decoder/node_block"): + n_feats = self.node_mlp( + node_features=grid_node_features[:num_local], # local grid nodes being updated + edge_features=e_feats, + src_indices=dst_indices, + ) - return n_feats + return grid_node_features + n_feats class DGraphCast(nn.Module): @@ -320,8 +350,7 @@ def __init__(self, cfg: Config, comm, *args, **kwargs): super().__init__() self.hidden_dim = cfg.model.hidden_dim self.output_grid_dim = cfg.model.output_grid_dim - self.comm = comm - self.embedder = GraphCastEmbedder(cfg=cfg, comm=comm, *args, **kwargs) + self.embedder = GraphCastEmbedder(cfg=cfg, *args, **kwargs) self.encoder = GraphCastEncoder(cfg=cfg, comm=comm, *args, **kwargs) self.processor = GraphCastProcessor(cfg=cfg, comm=comm, *args, **kwargs) self.decoder = GraphCastDecoder(cfg=cfg, comm=comm, *args, **kwargs) @@ -340,26 +369,21 @@ def forward( Returns: (Tensor): The predicted output grid """ - input_grid_features = input_grid_features.squeeze(0) input_mesh_features = static_graph.mesh_graph_node_features mesh2mesh_edge_features = static_graph.mesh_graph_edge_features grid2mesh_edge_features = static_graph.grid2mesh_graph_edge_features mesh2grid_edge_features = static_graph.mesh2grid_graph_edge_features - mesh2mesh_edge_indices_src = static_graph.mesh_graph_src_indices - mesh2mesh_edge_indices_dst = static_graph.mesh_graph_dst_indices - mesh2grid_edge_indices_src = static_graph.mesh2grid_graph_src_indices - mesh2grid_edge_indices_dst = 
static_graph.mesh2grid_graph_dst_indices - grid2mesh_edge_indices_src = static_graph.grid2mesh_graph_src_indices - grid2mesh_edge_indices_dst = static_graph.grid2mesh_graph_dst_indices - - out = self.embedder( - input_grid_features, - input_mesh_features, - mesh2mesh_edge_features, - grid2mesh_edge_features, - mesh2grid_edge_features, - ) + comm_patterns = static_graph.distributed_comm_patterns + + with TimingReport("model/embed"): + out = self.embedder( + input_grid_features, + input_mesh_features, + mesh2mesh_edge_features, + grid2mesh_edge_features, + mesh2grid_edge_features, + ) ( embedded_grid_features, embedded_mesh_features, @@ -367,28 +391,31 @@ def forward( embedded_grid2mesh_edge_features, embedded_mesh2grid_edge_features, ) = out - encoded_grid_features, encoded_mesh_features = self.encoder( - embedded_grid_features, - embedded_mesh_features, - embedded_grid2mesh_edge_features, - grid2mesh_edge_indices_src, - grid2mesh_edge_indices_dst, - ) - out = self.processor( - encoded_mesh_features, - embedded_mesh2mesh_edge_features, - mesh2mesh_edge_indices_src, - mesh2mesh_edge_indices_dst, - ) - processed_mesh_node_features, _ = out - x = self.decoder( - embedded_mesh2grid_edge_features, - encoded_grid_features, - processed_mesh_node_features, - mesh2grid_edge_indices_src, - mesh2grid_edge_indices_dst, - ) - output = self.final_prediction(x) + with TimingReport("model/encode"): + encoded_grid_features, encoded_mesh_features = self.encoder( + embedded_grid_features, + embedded_mesh_features, + embedded_grid2mesh_edge_features, + comm_patterns.grid2mesh, + ) + + with TimingReport("model/process"): + processed_mesh_node_features, _ = self.processor( + encoded_mesh_features, + embedded_mesh2mesh_edge_features, + comm_patterns.mesh, + ) + + with TimingReport("model/decode"): + x = self.decoder( + embedded_mesh2grid_edge_features, + encoded_grid_features, + processed_mesh_node_features, + comm_patterns.mesh2grid, + ) + + with TimingReport("model/final_prediction"): + 
output = self.final_prediction(x) output = input_grid_features + output return output From ec9b6b8e2463222b9465266d3ba0a20cf29646df Mon Sep 17 00:00:00 2001 From: Shehtab Date: Mon, 13 Apr 2026 02:35:57 -0400 Subject: [PATCH 2/7] Add compute benchmark pdf --- .../cost_model_benchmarks/figures/compute.pdf | Bin 0 -> 19308 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 experiments/cost_model_benchmarks/figures/compute.pdf diff --git a/experiments/cost_model_benchmarks/figures/compute.pdf b/experiments/cost_model_benchmarks/figures/compute.pdf new file mode 100644 index 0000000000000000000000000000000000000000..f60f93169802441205acbaa7ec6f0d27bcad62cb GIT binary patch literal 19308 zcmdtK1yof{^f#;$%9T(`QNRlb2okrrmrF@^ch@DAZc!;gK)OLfDFp-pDN#VWQ)vMy zB?J*fN|g5;)aNnye~;g{zV)tmz1%g-oP8$t?7h$I-^`pj%*x`DoNz7_ggJi<`tT_P z4uwJOO)oaEO zMeb+d)>bmTY>sn*V!mBUc(_QZxtO?^Ls8!d$|f!@=1z7{9`F*vt!8CnW^HE)MSs8R zWRFubcY*2w&5BC`teAVaK)GdX0Tx8R$70`O8K}We?7;u*0LVAUy^}e>-8c5RRn494 zU7c`1e?b0$_7%*{tW8AiJ%AQrz>fz8L7KVl*;BYPs5{}^k+Ts=idI?nJ z48?xqRm9HD9=L*o>mbK598&bx+sVVd6EXpG4!Kz7ugAYvAy$^PjHX+{G};5BgGh| zIJT+RHp+xq%~z=!>K{YxnMIS-rHPEUbpxmu&&|d6T0M*ofj7mvCbnIBqrCH_VfWI< z&5s{@NVNR5)0{qXzr=LB)(wE|aW8J(_U^pn)C ziGeKBorDhNMv9Zhv!ShLsFL0goS9BoklKBocIoA$1NAB4sM*wwpki6as2h8Bo~zM% z(oN#cJ={hLJ>F8zcZ|NeKFzfJ*a39OrM*ojc5rva-7|m<{{dc&*}&S+jzMAln(Z^n z&(m+#2N)NKm=w&^HyD)LD16tOc48M{qgf1=YEkPR_4`FY~pHu>QUVu zcU9cBOFW2UsrARR8T^u`I}5N$BJSZ z>dH@OOg7Y!G7htf4|2XdCHkd*I{LAY^^9yPq?he|qE+IPNR3b|gTxS@#yzDmqCS`- z(cWSvCzJX180&JRMGEY^;`-=_YW5KqS$>u?EL{BNcq7`2>b^pbtl~3`=H^wa>F3o5 z_?qY_NGPPB6>Qc9j9igIS0AY35fH{@qI0?2F21<B1@spdO}2*(4); z@mg7GlQaXozBCV6amcu?2f};=X(}WXR;Wk$Z6m4|eb#Gslubu^J6`u3j*%(0{yS^O=-+W#nx1AqNcD;9)wpw$`(UKIiZKG;;bRyJvq9#Pb z8%>JyLqC68e3RaNNm;RAJF)gHrRrICUw6V96OJhR?Qk)rNyk3cWwzP%cl3|@wm$HO zMPaulN9VP+qUM+?wrpyS1kxUfdrCk>>xOfDP*z*5OJFp>KGJOcDeLf?&DSg8vON1AEHy7hVHS8NOs@&9}haKIoRTHcXUOXLO*oIN&E*H7wbo<$Z z&-fD$wGf)sB(H_5sV|w(rWd``TQ#hXE?%_Vl3XN#c)t_=rI;5&d;Nu1Rh~zoVVl)Q 
z9zIR`MlK)u{?A|kDMh&0#0S+OPX4@60A|wuguec2Fx1`r?&{ajlouTn>}NPIU28K_ zwgImNXUU#i-B!@Y8sA!qkc+s^Cjo0q?CcRdRP)Z>14Z3Ay3Ge(m#%w6s} zmXw9B(D}vKp=qNUAust{bE)LzQSuWtQ2J*BCmtlNkJH~;kNmO|wiL&l zZ`yV_V7I-SilvD|`^5|#tHspfqK_>~U9~kwNK!I0ZKW9Q%1M}4vskXZzp62L znCLv_#knX0g!BEG6i@Gk>x~90DmuIj$Im4dt{4?o$UQN#)46(ex@Pv`9+`o&*HuSLjFo0d<=z_FmWxduP;z8i7e&(a#Tx)1qs#H0JmVbC9HknpPf=U`URxWf#Whb+x()+W#ese1& zbe1up8__3z7bbamz5i8!ISZqGBpIKDcOwU~5Bk=ub5L@q6G1^|^5r%|LvrQ|TNX*k z=_jyTbdH#$pgKLN&@9&MD=ce{H6jVlf!XFPx5R@50%Gn2$%{Yuz?B|&`^Dai&=A+D z$5oD28%dbM<~Ke%Iy>->XPmy@_$B{xOhuIYndRAqs8a$_t!&Y(q&V`3FU#UVChT06 zho0&5jy~y(_*~iTQ9?UW>YTt_q3?tdTYVhHheF7seFD2+zUG{h2IAwmCDb$O}106JJ?hFzv2?v6JKQa?{{;pW{nfc*mm%>p{zo zh3ooHUA_qKS}g3=Pa+YXzWnyTrZj(m)eAI^*Huy#^9?>7BtA})aJm-36>DE0WTz;<@Z9yLv9!LzG@=~pl`Wx zxgmx7TC>w(UxUj|mo%&jR`rOeto^)6>NM8hZ<~*vdg;}jbjhtd>CszSdHPb#yY?Lx zUr&Cn@d&F)-EyqI;?h<(B=`PCiH61a)ewQk^X0N3!L0ZCZ8BdDp!wF0EmVfRpV;Ma zwqJg5?pYUS{gLNVp3In0lImmEmP#Fj#^pcDG|nXtd>EQ{lFOhj{1nratouAMt9>|! z`2AY1R=^EW#+K_>`;f+rwe+-!6}K>CyZ+gRh7}U*N4m8O)+n$(seMuLQjl^v<{Fa4 zqQ(>26OWuP@mMv}##&Lwo|J>!8|>)#SdcTng|2tnF1z2#jK|R3DN0_eplYi?o

Y z(N--%_QY?0S?jDDZ+Hj{XRp0Kxn5(v@`*a~5rJni za}jS{w1#4yc9#VwtB7y^!Xq#HIdrY_cKG_xjZrpMR4NK`LV8ZG-cr5Ap2)H&G5ryq z5{>=Y(ib|O1jn+tCmyH=u9l~*d<`$7^Jw?sk3XHI<0HIt#7ow7DNkg#{f+T!U)C^> z$B7+TdOpH@OkR?%w5z8nE7P0tl=%H|40hQ#fP79{F!S==)*HG3#vL!v!q936;CR_e z8(V~5@C!q4K$LO`*-Xp8UB4%7 zJYNW|{*s*9_Zo6e46TVg{a#2ofF?LLuS|G*EcbnoMGE4m6GcFYFacjdQFyRqp4{+) zH14qP`XYb&nViV2I+`g?WnaY`E#gcASL32fse$TiDDt8pj-0I31o$r=0+8sQXy6Nx zJIBAVBpi@N8NXX^0Cd?d7-VX#EUWql@23UsE*$gzr&}N2#n&5ulMNwzsN%u}Izs#x zxgSC=XVvHUB~arXS9M<3dh?(u| zxy_5cw>KAUFFc%%cg(k_iAgSlU+W@R8REE*ZPCW7j?^oRaTQv2KLY?34z_@|w}4C` ziWV=tOpGSIn&Olp)YoyJB)`UqRdv25;_`TID#FGr*PMRGR+Ov+-IDtDw)>U%SW)K& zUiFfh1WNsh*Ci)CDlBB!P?%6Lp3gZB!C9Nj`m1(b(W+K{jdUt>MO zxhvQ^>0#ezh1l9>JaV$>l_phdA_!k2cW+gutcVES*M`|L+)6pGY+xw7{&&X{`3#<) zVxOpV3O4Fj;;JY#Jsi#wZP219x4Sq!-9A~?eto#)i7ZiK>|1xOa_-^$DnsJ40(%=m zG3%yGL$%#!EeEN}JUH3R%>CR>A?FCoHPv*-Z*c5+R+c6T6eZt`;L=aCi75&jwNH~& zj6paSdkZ&|ES#!*zZL{O%ev6c?*lr%!l~+K4!D@Nny@{nu5SGj*5FizST5 zh~vG^A71pQk@(!CaO4uhR~;i#FXql<|K%XD8fLs3vDZHjzSH6=sqO3X()a z!R5mm)2SuLyn~~!KdftvPtbjNIl@%wM)-ATYG~KOJqMn|pe5(ntsC7o;fW!2_~q6* zEw$7I@N+6!fkc7Zh|9JkT(dbYWr;qLD_bF}MJxk2zh2Ja6CJNi4V+8X7vXh=r>q?D zgGi5`iXalId*dk_EO>uHo`v`PC^XIUV`0&N+4LpEKzwocmhn*S)0gUE+~s_d$z+g< z+*{4gO)c|xv@Ooux#&s!{C&S znHUQ~NlR&Q?N_-ktPEaN>b4t3$5f~6Ryd#biwk;?FfLDpb2wEI=EnD!O`oOBVla46 zzjJf6vdM&{O^A zJ#DZWCN0&kP|? 
z2E)C}Qrs4@SjuoW2d!2zp2VY4Y`(6s3*qVbdBXQ$&yj1>^zOWlQT#2 zZ!ek7_b@fhkC<%LGFY#Z(ep>dpwV%+q+u zx_V*pj!*#9n&|vwDo-EH=S3+OL;SI+4dyPMsn&PK0k($Eq8cHTA6pP^uy%|ChLb?5&8b>_~bl~zKW>oyPk`~34!DR7LopAF^NY-#_kaO*t0pRufo|ytSd=*i|JqX%5JG}ZVALwZ2rx0_>&0JW-|4X-C`xokyn9@5WWDYlhbDvbReOj&oEfj`YipXIL)D z=(`yz*;*T&cBRA+M3J83uR6NT_?Clu-QH^wmN|7HxygVBH~ z7nL3&zklm3oW5FEf0xc2%RPY4v|znM}146G4G{OmK;ZtbD?{HCi)w?)X7 z_X?TmD3|L_Fv;mpzi$+$HY`^j5)GVbkTfJp_FcI-n6h=Ic>2<)WLDfIbv-BrT;;rr z3g*nFG^PsD=auEBmA9$=@ud63Ekf(ctqPV-IE%To zLw$uX3`fPS!wmzno^Ld`+M=G7p-WivZc5$vUy~_0bLSAI&|y0XC0Yd>k}es z;&Yj!0daWte#NY=#s-FXRHD>0X0xyOPld=@P(Q!Jl1uU8yyJv;I=_N!Dp~2vSCavA zc%K&NmJZZB|_?;VJ3CKibMd+wjDG!dp+I$Jl-n^)-th0?elYpCA1ZWof!A+osDdX0vD;p0)#N!*JAu_=?zPkkfLun1$3v6+`? zMR!{{&7_RwsZYpXscxQ5c5;8#eq}VizB;gdhj&ID`;h%g6#8V6#FevklJrYOJ(&-! zI&|2_Frf+t6mbp7P49Ka+oo@?6zb0maMX2fHdH*PJ*PE3&)zcJ?mwa$^Y#VyQ(=7H zjzeikfajFdH5O?H*N0Gz?7{9)HP0_%!B-3GY{|%nP?^N2nz($3Kmk)KRoT-ra{B+2dRuq(dxedK)J%RrTJZ z@8{tn!nkN<-oKi--1_)vZelk-n?=z(7ga91Q-KWht!p_S-eeA65$)M3y2!jP zJ$wu4#QNT<`E7sebmZ5GFnqa)M^+3UCkyQcxs0pS_#{lt6rJ*|Lo^gOl__6VHdGUz z_+%p?#^xL~uxQlKef&zi4Yh~t4gr1UIEx&q=E$oY#gzhqk>1U7xK3=_!2%f1>Cf--n%3?US;2hpzCE z1s>4;0hr{k9u^SC{o`f_-&e6V-3#2<8&YkjJSNOs{)E{S$8>}nzJ(1M+luDNWbv|e z(%R)v3XrNVn3SRPO_(xt^XoXgh{naGXCWd`T7fX{C2hS*9(5+whwolKj$B}{X9yc=yq*U<_$&ATzr15`3zM| zlQyBCN~Ik9M&zPRI%|5>g}XOxgf0h(CtR?c46duY&`Y{hxokPD!C+-D{=gQOBPrYR z@;QluLiDk1aU#od)O+f$-aPVP~ z`1yke<>=vU@2Am8ijf}}+Fnx7C3z$qMFxP?yYT8n; zbYg#(@`hNbvwGHgTR7tFJZ}_hTFA=AD-JiT7wNm>=r9JVxh-~2KlIIvrK>AP&Nzwe z8H{hyOFWxrhMpy&9#k=XpYq1vsM^hizU%_Iv*;ZWE((UO1@`+V?kMM%WZqL%amW>V zFrRqmHZ*p%hH%8}=L zPx?v`!($6)m<#Sb@_im6x@tH5?q)sop_nSP+mkHFW$8XTih_?W1$oH?gA1>ffD@)H zxZD>!%7Xn|5Ak$hcGrWNW4RTr#~Lj-3!PF`4Q|W68{2-# zwmlZYXVHhxL0&m&Mxnsl@k~ZFlqMFjNPpjr@kYKtYbM{knL>9yQ_{#wrB_K?v49=J@XT35;p6Pf-w~g#2&M~T~Y^mmOnTi?&o$X0K`9U!K3b{^k^(vMl zGs(H;=(&I=_{j9bxEnHv%AB`x3>1pQuC%1OS@ln@m^Mu7_6dZI*?pF@FlX^6i7*o# zJM9w$eRt&XX^vPlngXHIjkiqFMxtvxlzUbs#P4G4AdyRkW@BFGWXan>H;fqP+GNg+ 
z)g3F%1ml1{xlh_2Mgc;Jxk+yCqeVTZ6!;lC=^LNrC$fy!a3^2rnoj_bSo+EwlPjI~as^mF> zI2J=H?W3}|G35y1>M1AAwlM-S2UhRVUiFmtm+sb4G8$}SS>#vuNG(|+bFTK`Qa6v| zJ~Ig%z~up;6p8wakwrrF8~%H)=HM$ml@R%{G!OK`ddczwU6Vw{HXa>n_s!aiDw&zi z55+3oU=+*0zl(A@&v<`JhKRr|E~4d%Tz=S6R1W3c(bbvab;w#P+XNP(mlp*5HY_Ea zBt!c{c=;HjbzflpreEBn*PYSF6i?eS6<2D#et&K%gx|HePHgE|16^ZbPk zA<;;6z)=^g!w_f{7#elNlp=sqjO*U9AG2iPJf|o3+MSJo2=U~_A%fGShyz+b0Mlar z0@5N870SR29d(%e7Bvgxtp(x!480~dxKfnslr(W;`{5E|pHk)&zaUGWM5yhFCMdOd zaZ)giI9_=;3Bn%XXte=FmrBW;Mso+hl1 zkp7|PJqF&JPwNV%jN6~OP^_rU3^b>Z6XkfwD#SlM?%7{j(_dpX`Rv_Wxq;wIT|K8# z$)3^Cj32<-0f_2v)(s96Rs0D}u2b?HM>~)7#XCS4)9T^~4^>){upZFtK|qR!_rFCV z5e3XTpvS~F)BKoxc4q}*WX<7~duWnwGqhVm&KG}Kru*a^bCg$MH`>;^%*q$sv6}p6 z2`D8lB(`bH7aw`Zuhy7!CHV3?ocZjNm;8iHXS)rsxn_=;p|0w#><;ohai3Jix7~DN zZ%C1vLx%I6X>BIkZ5vLrN+uV)C*2H&DzU0bH7!F&J{)BYyAk*BqC=|9MxPvv6#t7T zMdECT@H=AP<}S7MS@NZs8MirKWBaQorAJP*MlqbG^K(np-KU6!bMFKn)x>hf ztkeDX-P>XgpNu{|Hyy_&2m7|Q(;e$tR2$ny^JhGL+V(Z8Wh{Tl>*A7ziM5Vg2T?`p zGYes&F*WSyL%~buAbcAK^z|Tk1pf=O5mc@W!-o@h7zCWruJv5bwWJP_yT7+Cy}iNM zM=hbs&XW|)#_bx*tfMx;cRjw5^y;lQc4la1b_wp7BIeCr$m)&Yx0IXX1Mwwds01sl z*Ks6!-dVbra_Bn_>=aM$jWr;5NNip3Gz002^`O03#zW;$rxvo)#C1s&A&kVb3WR%) zoFeWThMW~@SQVp;iY9%_v#0%K&6%W7ErK=5&c0^aFeR51?<8&Xohl{+?P4zSlh!;U z3O%*DoIcMg48NW&WqL|1yb;{!ct>lnntZKsti0Ayzuan&xLh+&+M z{OVmWe~%9TA780m@vo9J4~LUO)w4zw2W_{h^CB|kY0*B zLm7dg^wuDyz*z+|zqFPhx+gJh^LQ(&dd05uhnG)ioHr;T#uXPlbr3Yc=V|h!dix26 z3yzk$t0DL9FagEZi*4`sb}Mc-4jVVY_Lug5-3wsD zes9lX0j1mb0$uj5Fh3w5PK33wj!@r>3Kf$#UP{DGF{>*T2%TrDcq)TEk*kcsSF(~P zMy6!RO~x(MmRU<>*QQX^Z(b!<`MG+hF>G9N5aDeiDP$K=G(n3J~VtBotRw);^f zp6Yjc-SectDV>78{9;1}e^ZcCGL7}tmz6Axfzg-;Y~WAz0h;==f9StZpF|{<91ngB z-?`e4Ra)O~Gryy(Zf7s|=%7@_!VwhhQ)-2FPf9DHg<{F{r@dF^Ib@^7>!_pYlpo-N z_%BGyOXH%Y%js`2-O(;FNK8^^FHxN;wQj0^KB7C&Djwz#wi~=asJ^STs2}@5;bvhp zeCM=!75~jpnJq(!FmeNBUJOUrXF^WGS^9TKclSjeYnR zZQ0%cJAKL9Op$kU83pcTwRN&UrXy15o!IeA7VkWAaXPEv_#my5Cpn}7DP!@jM!Gnu 
z{;C0Q$yZ`Ei?XY%Jx&?@xyM$DE#-Q9v5(TK<8`cDk}0aP{Ck=el@x2Vknb3|MClu4#UugHzmRMhTx(tQ2A_I{P$>+4|-=j+Y~sO10@g+%;?{qX>Zp8#0#?O}T& zoc%$m`(s{SNX=*aD-_jK@-X8WHg264ay9+Tud*9v_0DZF7mH3rEZ^x=ZQ7(Rq3q!- zl#KFFg7R$DO_@VCFK->C&qhL(=Y*S|nIcs2 z{LT+MTRR>Lc47(k&Z?YQ+t!lJp-BTdr671;?W83 zZWN2I->Quk5R!to`7EJ>Z8Hd1=IFDTI~ytaidv3TS=<@Wm^ZGc4`d6~I$n}PZ?mRH z5tl_*Fher}o9KY4nWj$_5%j+`HTQe}W`?J*K4@7b#-PFBeSDYuu|BJ>XQ|sKaa7Sh zQV)vBNvwHmgKE|U*mK&7L0pOfT{P~thdsK)Z^ zs=IXr2B?=Uxonrljgn(#-OU3UiP|2LFRR{?&Z*L57+6U{6jtz9;C@6d`uWrofLl?LZ=x_?Ip5$Pb%!j#%XnJXx}oej6-O!a+Hso@2E5}_8Nzq%DHvq z1vth0=vj#|5ntx$L~PSjm-W|My~X#w4u@i%^RX}J>FVHK2%TMyZRS8Mcxtb^d>m}K zc^tKpu{$?FOpY=Si*7zzFp9mEk))#OQ6UyX(gkl}_(+=8Ktm+EBS85)f|-IYw)HhB zNxHWYeP(D1dg932j7zpuM@gqk8VoCasbo)GiQYV%D?jhMuvfTRRx>noj&9kf`2gh{ z1SS8*TXU+&IN^yLd)9Mt-QVKDYh1#`z}dZ_W9>%?NkW8Obf@oy-i2~#RlSLNXw~o< z3HN(V5a@Mr$X0t+SRE>ZQ|3u`TlHM0A9hS6>WmpIBDTmU+YMzKn{S$e(^ysvnfISq zGAI;G5^2lVZ)B!8B|l)oWJ!7{3BgtBEQo>5NopsUMA~VdL6V1~0;CxN@{+T<@kPi( zA4!IxohUzNb4R$`(g@3eFFAAC8ucjfUg=%Z&R?62r8jWoVm-y3PrLpJjdx1LI)6=A zy7r45*(c}6F|R%3=DMk0&RpwiTdP+$Zb}}q6GjhAzpe_J`5am}H@CfuXJo7wbbt&G zz?yK(UpPvINEuSG{k^05Q#t;~+Bt#ec~7lYo)?H6b@knyg*pfb^g5pBaxt3^4ZoO7 zJ5yT%#ZC9|?zKG;YJ`i7XK-pKxV4OQk~`NU;77_O+eQR3YROvkkCB@toU5sYG8#l| zYkt&;S?{UfACGuiYGEXpQ5(E;tf}a*md}@`ExvN~+lDV^6WEz?6oG5Li;Svh{`3=N z>Pt>}w@b#8FFH`I@AiqkUnVGH5_o+;w-3OvynnI6$OAxZ9~iOtHKnPO)6YkaGC+o( zH_15QUXRBZs!j{P;X4_A-Rzyop`LKKaVUju`2YqFK%xlvU+s-zNr5Th(kAe_$@x%j52cI?U0%=FA?+6?g z8|uqZ5j#cS+n=p=p=bK5=);86{7zkPseE@3JFGm?eO~p6z_CVFA(JY%_s9O z<5ls0fhit3MB&(R#O$tYY{_h+eerq;q1MjV!*`|Bau?h4{QVQVu~iVz-czo->d#sUi8}uHcmJUV2=|8%;4b|EDAVw2KYbV z(_=5!D z7Pq#rFb7W7fk)}|px~abvx5l^IOPZQ&jCDgXJc;R@;&KfZE58KMFBQsb=ASj+zyHY zX5F6VPWDi4dpn?0F780e#mdRt9Kvk@9M6MtTUfgRbZ3BHG_W};21Nr;hy>6ND7O?8 z0|-L~3cf?207wN2zSg4xxy>Nl=D^83Agd)10WDht5s={ul-mZ% zZ42eL1KNhd;lOwUGzTaM;MEnl21>gFVgtygfUXR{Q+?c8)@CkNz&Stg!T9+EXkqWdcqia6uIV~kh?z#-TIc)^K40{fR@CJxf(`x1sh zxV3*I0T}_uH?46Zc9u5g0OQ-2$27; z^&kO6hysih1@3>bJ2Pz06-0t6@;*hJ>x 
z0V)BWIt*ah7+!z@!1DHE@ILUa9vV2FiUG=lbz*_x9?koW2FClHL1G|a3^DQ zfv~R+u(@vm!t-5c|5WuY0WkQj3gA^ouow@>6~H-I95jO;>iQ7_DueuBWWNR!2FL`w zPzkaR+QC2dgSkNBzA*tcf!F`k0x|`ILH6%`uK)`SG(Q4p2;Xb~w1Z#6zN`Ve0Bz!@ zux}5bP5cP^w({L3KzsN(fPC#=19gD5@gsmeN5gmk9Rj?9HgG_|@&Y;m?E-MjKZhSS z@goL>0_YvGAA{+i!oIG*-3Rmh6F|EFwF<`nzJ&Vo{8mpeX{f0aFlQ(B6Ly3!te2n*Mgz1oR~P`M}{9a9;aq z928)0KV=3y0pC*QfTF<^;4wg70X)R}8TQ==AlUs2OTg~GrGEIG@4nR<3T*c6XSo6e z#=!k*Y@n#`&Dlax-%1iSyk^D2NN5pWF%PXT!JzpG9G z0R^6oZ&UyceV3d9fQi0~PJsss3;P56 zVdd6pr$nbd&G@cK)bn5USydo!7Riub$QHS!AvgfzeyC8Ql_B{uQP#*!d4bUzCOLIF znrG&U1HMD@b1?@q3#h3H?h4S%k1Bwy`wdvz+lhl;aJeAPhXCHK zfT7?>7y<#~L2$xQtS}fW+fPHNhm*Ml1Oz4!07HNO10XUF;I}QH|Ionk82p3UeW$@O zzz0`;rvWGgFq2>Np+T7ZJBR6aayLr=dZ3^;>;dV9N9x4GjaP zT))#0AiVmuEMNz}^9(~{{*V Date: Mon, 13 Apr 2026 13:01:13 -0400 Subject: [PATCH 3/7] Add cost model benchmarking --- .../analysis/__init__.py | 0 .../analysis/compute_predictions.py | 160 ++++++ .../analysis/fit_overhead.py | 226 ++++++++ .../analysis/fit_primitives.py | 255 +++++++++ .../benchmarks/__init__.py | 0 .../benchmarks/bench_compute.py | 232 ++++++++ .../benchmarks/bench_concurrency.py | 227 ++++++++ .../benchmarks/bench_end_to_end.py | 504 ++++++++++++++++++ .../benchmarks/bench_gather.py | 181 +++++++ .../benchmarks/bench_pingpong.py | 161 ++++++ .../benchmarks/common.py | 176 ++++++ .../cost_model_benchmarks/figures/compute.pdf | Bin 19308 -> 0 bytes .../run_local_compute_tests.sh | 12 + .../visualization/__init__.py | 0 .../visualization/plot_ablations.py | 160 ++++++ .../visualization/plot_compute.py | 172 ++++++ .../visualization/plot_gather.py | 120 +++++ .../visualization/plot_pingpong.py | 147 +++++ .../visualization/plot_tipping_point.py | 141 +++++ .../visualization/plot_validation.py | 127 +++++ 20 files changed, 3001 insertions(+) create mode 100644 experiments/cost_model_benchmarks/analysis/__init__.py create mode 100644 experiments/cost_model_benchmarks/analysis/compute_predictions.py create mode 100644 
experiments/cost_model_benchmarks/analysis/fit_overhead.py create mode 100644 experiments/cost_model_benchmarks/analysis/fit_primitives.py create mode 100644 experiments/cost_model_benchmarks/benchmarks/__init__.py create mode 100644 experiments/cost_model_benchmarks/benchmarks/bench_compute.py create mode 100644 experiments/cost_model_benchmarks/benchmarks/bench_concurrency.py create mode 100644 experiments/cost_model_benchmarks/benchmarks/bench_end_to_end.py create mode 100644 experiments/cost_model_benchmarks/benchmarks/bench_gather.py create mode 100644 experiments/cost_model_benchmarks/benchmarks/bench_pingpong.py create mode 100644 experiments/cost_model_benchmarks/benchmarks/common.py delete mode 100644 experiments/cost_model_benchmarks/figures/compute.pdf create mode 100644 experiments/cost_model_benchmarks/run_local_compute_tests.sh create mode 100644 experiments/cost_model_benchmarks/visualization/__init__.py create mode 100644 experiments/cost_model_benchmarks/visualization/plot_ablations.py create mode 100644 experiments/cost_model_benchmarks/visualization/plot_compute.py create mode 100644 experiments/cost_model_benchmarks/visualization/plot_gather.py create mode 100644 experiments/cost_model_benchmarks/visualization/plot_pingpong.py create mode 100644 experiments/cost_model_benchmarks/visualization/plot_tipping_point.py create mode 100644 experiments/cost_model_benchmarks/visualization/plot_validation.py diff --git a/experiments/cost_model_benchmarks/analysis/__init__.py b/experiments/cost_model_benchmarks/analysis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/cost_model_benchmarks/analysis/compute_predictions.py b/experiments/cost_model_benchmarks/analysis/compute_predictions.py new file mode 100644 index 0000000..3e69f4d --- /dev/null +++ b/experiments/cost_model_benchmarks/analysis/compute_predictions.py @@ -0,0 +1,160 @@ +"""Analysis — Apply Assembled Cost Model and Compare to Measurements. 
+ +Reads: + - data/fitted_primitives.json + - data/fitted_overhead.json + - All end-to-end JSON run files + +For every run, computes the predicted T_layer: + + T_layer = T_comp + max(T_intra, T_inter) + T_buffer_copy + T_overhead + +and records the measured median, predicted value, absolute error, and relative +error (as a fraction). + +Also computes aggregate MAPE for the fit subset and the held-out subset +(using the same --fit-filter expression as fit_overhead.py). + +Outputs ``data/predictions.json``. + +Usage:: + + python -m analysis.compute_predictions \\ + --primitives data/fitted_primitives.json \\ + --overhead data/fitted_overhead.json \\ + --e2e-runs data/e2e_*.json \\ + --fit-filter "world_size <= 8" \\ + --output data/predictions.json +""" + +import argparse +import json +from pathlib import Path + +import numpy as np + +from analysis.fit_overhead import ( + apply_filter, + load_e2e_runs, + predict_layer_time, +) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def parse_args(): + p = argparse.ArgumentParser(description="Apply cost model and compute predictions") + p.add_argument("--primitives", type=str, required=True) + p.add_argument("--overhead", type=str, required=True) + p.add_argument("--e2e-runs", nargs="+", required=True, metavar="FILE") + p.add_argument("--fit-filter", type=str, default="world_size <= 8", + help="Same expression used when fitting overhead (determines train/test split)") + p.add_argument("--output", type=str, default="data/predictions.json") + return p.parse_args() + + +def main(): + args = parse_args() + + with open(args.primitives) as f: + primitives = json.load(f) + with open(args.overhead) as f: + overhead_data = json.load(f) + + T_overhead = overhead_data.get("overhead_seconds", 0.0) + + all_runs = load_e2e_runs(args.e2e_runs) + fit_runs, held_runs = apply_filter(all_runs, args.fit_filter) + + fit_set 
= set(id(r) for r in fit_runs) + + prediction_entries = [] + for r in all_runs: + T_model_base = predict_layer_time(r["config"], r["per_rank_stats"], primitives) + T_pred = T_model_base + T_overhead + T_meas = r["measured_median"] + abs_err = abs(T_meas - T_pred) + rel_err = abs_err / T_meas if T_meas > 0 else float("nan") + + # Decompose prediction for ablation figures + net = primitives.get("network", {}) + intra_bytes = r["per_rank_stats"].get("c_intra_bytes", 0) + inter_bytes = r["per_rank_stats"].get("c_inter_bytes", 0) + + def net_time(nbytes, mode): + params = net.get(mode, None) + if params is None or nbytes == 0: + return 0.0 + return params.get("latency_seconds", 0.0) + nbytes / params.get("bandwidth_bytes_per_sec", 1e10) + + T_intra = net_time(intra_bytes, "intra") + T_inter = net_time(inter_bytes, "inter") + T_comm = max(T_intra, T_inter) + + F = r["config"]["feature_dim"] + send_bytes = r["per_rank_stats"].get("send_total", 0) * F * 4 + gath_params = primitives.get("gather", {}).get("clustered", {}).get("gather", None) + T_buffer_copy = 0.0 + if gath_params and send_bytes > 0: + B_g = gath_params.get("bandwidth_bytes_per_sec", 1e12) + T_buffer_copy = gath_params.get("intercept_seconds", 0.0) + send_bytes / B_g + + entry = { + "source_file": r["source_file"], + "config": r["config"], + "partition_stats": r["per_rank_stats"], + "measured_median_seconds": T_meas, + "predicted_seconds": T_pred, + "absolute_error_seconds": abs_err, + "relative_error": rel_err, + "in_fit_set": id(r) in fit_set, + "breakdown": { + "T_comp_seconds": T_model_base - T_comm - T_buffer_copy, + "T_comm_seconds": T_comm, + "T_buffer_copy_seconds": T_buffer_copy, + "T_overhead_seconds": T_overhead, + }, + } + prediction_entries.append(entry) + + # Aggregate MAPE + def mape(entries): + errs = [e["relative_error"] for e in entries + if not np.isnan(e["relative_error"])] + return float(np.mean(errs)) if errs else float("nan") + + fit_entries = [e for e in prediction_entries if 
e["in_fit_set"]] + held_entries = [e for e in prediction_entries if not e["in_fit_set"]] + + mape_fit = mape(fit_entries) + mape_held = mape(held_entries) + mape_total = mape(prediction_entries) + + print(f"[predictions] Fit MAPE={mape_fit*100:.2f}% " + f"Held-out MAPE={mape_held*100:.2f}% " + f"Total MAPE={mape_total*100:.2f}%") + + result = { + "fit_filter": args.fit_filter, + "T_overhead_seconds": T_overhead, + "aggregate": { + "mape_fit_set": mape_fit, + "mape_held_out": mape_held, + "mape_all": mape_total, + "num_fit": len(fit_entries), + "num_held_out": len(held_entries), + }, + "predictions": prediction_entries, + } + + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w") as f: + json.dump(result, f, indent=2) + print(f"[predictions] Written to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/analysis/fit_overhead.py b/experiments/cost_model_benchmarks/analysis/fit_overhead.py new file mode 100644 index 0000000..426f440 --- /dev/null +++ b/experiments/cost_model_benchmarks/analysis/fit_overhead.py @@ -0,0 +1,226 @@ +"""Analysis — Fit Library Overhead Bias (T_overhead). + +Reads ``fitted_primitives.json`` and the small-K subset of end-to-end runs. +For each run it computes the model-predicted T_layer (without overhead), then +fits a single scalar T_overhead that minimises MAPE: + + MAPE = mean( |T_measured - (T_model + T_overhead)| / T_measured ) + +The subset used for fitting is controlled by ``--fit-filter``, which is +evaluated as a Python expression where each run's config fields are available +as local variables (e.g. ``"world_size <= 8"``). + +Outputs ``data/fitted_overhead.json``. 
+ +Usage:: + + python -m analysis.fit_overhead \\ + --primitives data/fitted_primitives.json \\ + --e2e-runs data/e2e_*.json \\ + --fit-filter "world_size <= 8" \\ + --output data/fitted_overhead.json +""" + +import argparse +import json +from pathlib import Path + +import numpy as np + + +# --------------------------------------------------------------------------- +# Cost model (without overhead) +# --------------------------------------------------------------------------- + +def predict_layer_time(run_config: dict, per_rank_stats: dict, + primitives: dict) -> float: + """Predict T_layer for one rank using the assembled primitive model. + + T_layer = T_comp + max(T_intra, T_inter) + T_buffer_copy + + Parameters + ---------- + run_config : dict + Config block from the end-to-end JSON (feature_dim, model, etc.) + per_rank_stats : dict + Stats for rank 0 from per_rank_stats list. + primitives : dict + Loaded fitted_primitives.json. + """ + F = run_config["feature_dim"] + model_type = run_config.get("model", "gcn") + + n_local = per_rank_stats.get("n_local", 0) + n_halo = per_rank_stats.get("n_halo", 0) + n_total = n_local + n_halo + + # A rough edge count estimate: use avg_degree * n_local as a proxy + avg_degree = run_config.get("avg_degree", 20.0) + n_edges_local = int(n_local * avg_degree) + + # T_comp + comp_params = primitives.get("compute", {}).get(model_type, {}).get("forward", None) + if comp_params: + T_comp = (comp_params["coeff_V"] * n_total + + comp_params["coeff_E"] * n_edges_local + + comp_params["intercept"]) + T_comp = max(T_comp, 0.0) + else: + T_comp = 0.0 + + # T_intra and T_inter + intra_bytes = per_rank_stats.get("c_intra_bytes", 0) + inter_bytes = per_rank_stats.get("c_inter_bytes", 0) + + net = primitives.get("network", {}) + + def net_time(nbytes: int, mode: str) -> float: + params = net.get(mode, None) + if params is None or nbytes == 0: + return 0.0 + B = params.get("bandwidth_bytes_per_sec", 1e10) + t_L = params.get("latency_seconds", 
0.0) + return t_L + nbytes / B + + T_intra = net_time(intra_bytes, "intra") + T_inter = net_time(inter_bytes, "inter") + T_comm = max(T_intra, T_inter) + + # T_buffer_copy (gather of send buffer) + send_bytes = per_rank_stats.get("send_total", 0) * F * 4 + gath_params = primitives.get("gather", {}).get("clustered", {}).get("gather", None) + if gath_params and send_bytes > 0: + B_g = gath_params.get("bandwidth_bytes_per_sec", 1e12) + T_buffer_copy = gath_params.get("intercept_seconds", 0.0) + send_bytes / B_g + else: + T_buffer_copy = 0.0 + + return T_comp + T_comm + T_buffer_copy + + +# --------------------------------------------------------------------------- +# Load helpers +# --------------------------------------------------------------------------- + +def load_e2e_runs(paths: list) -> list: + runs = [] + for p in paths: + with open(p) as f: + data = json.load(f) + config = data.get("config", {}) + for meas in data.get("measurements", []): + per_rank_stats = meas.get("per_rank_stats", [{}]) + rank0_stats = per_rank_stats[0] if per_rank_stats else {} + trials = meas.get("rank0_trials_seconds", []) + if not trials: + continue + runs.append({ + "config": config, + "per_rank_stats": rank0_stats, + "measured_median": float(np.median(trials)), + "source_file": str(p), + }) + return runs + + +def apply_filter(runs: list, filter_expr: str) -> tuple: + """Split runs into fit and held-out sets using filter_expr.""" + if not filter_expr: + return runs, [] + fit_runs, held_runs = [], [] + for r in runs: + env = dict(r["config"]) + env.update(r["per_rank_stats"]) + try: + if eval(filter_expr, {"__builtins__": {}}, env): + fit_runs.append(r) + else: + held_runs.append(r) + except Exception as e: + print(f"[fit_overhead] Warning: filter eval failed for run ({e}), including in fit set") + fit_runs.append(r) + return fit_runs, held_runs + + +# --------------------------------------------------------------------------- +# Scalar overhead fitting +# 
--------------------------------------------------------------------------- + +def fit_overhead_scalar(fit_runs: list, primitives: dict) -> tuple: + """Fit T_overhead to minimise MAPE on fit_runs. Returns (overhead, mape_in_sample).""" + if not fit_runs: + return 0.0, float("nan") + + residuals = [] + for r in fit_runs: + T_model = predict_layer_time(r["config"], r["per_rank_stats"], primitives) + residuals.append(r["measured_median"] - T_model) + + # Optimal scalar overhead that minimises sum of |err - overhead| / T_meas + # is the weighted median; for uniform weights it's just the median of residuals. + overhead = float(np.median(residuals)) + + mape = float(np.mean([ + abs(r["measured_median"] - (predict_layer_time(r["config"], r["per_rank_stats"], primitives) + overhead)) + / r["measured_median"] + for r in fit_runs + ])) + return overhead, mape + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def parse_args(): + p = argparse.ArgumentParser(description="Fit T_overhead from end-to-end runs") + p.add_argument("--primitives", type=str, required=True, + help="Path to fitted_primitives.json") + p.add_argument("--e2e-runs", nargs="+", required=True, metavar="FILE") + p.add_argument("--fit-filter", type=str, default="world_size <= 8", + help="Python expression evaluated per run; True → fit set") + p.add_argument("--output", type=str, default="data/fitted_overhead.json") + return p.parse_args() + + +def main(): + args = parse_args() + + with open(args.primitives) as f: + primitives = json.load(f) + + all_runs = load_e2e_runs(args.e2e_runs) + print(f"[fit_overhead] Loaded {len(all_runs)} run(s)") + + fit_runs, held_runs = apply_filter(all_runs, args.fit_filter) + print(f"[fit_overhead] Fit set: {len(fit_runs)} Held-out: {len(held_runs)}") + + overhead, mape_in = fit_overhead_scalar(fit_runs, primitives) + print(f"[fit_overhead] T_overhead = 
{overhead*1e3:.3f} ms in-sample MAPE = {mape_in*100:.2f}%") + + result = { + "overhead_seconds": overhead, + "fit_filter": args.fit_filter, + "num_fit_points": len(fit_runs), + "num_held_out": len(held_runs), + "in_sample_mape": mape_in, + "fit_subset_runs": [ + { + "source_file": r["source_file"], + "world_size": r["config"].get("world_size"), + "feature_dim": r["config"].get("feature_dim"), + "measured_median_seconds": r["measured_median"], + } + for r in fit_runs + ], + } + + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w") as f: + json.dump(result, f, indent=2) + print(f"[fit_overhead] Written to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/analysis/fit_primitives.py b/experiments/cost_model_benchmarks/analysis/fit_primitives.py new file mode 100644 index 0000000..02c9233 --- /dev/null +++ b/experiments/cost_model_benchmarks/analysis/fit_primitives.py @@ -0,0 +1,255 @@ +"""Analysis — Fit Primitive Cost-Model Parameters. + +Reads JSON outputs from benchmarks 1.1, 1.3, and 1.4, fits the following +parameters by linear regression on **medians** of per-trial times: + +* Network: T = t_L + bytes / B → fit (t_L, B) for intra and inter +* Compute: T = coeff_V * |V| + coeff_E * |E| + intercept + (separate fits for GCN and edge-conditioned models) +* Gather: T = intercept + bytes / B_gather + (separate fits for contiguous, clustered, random distributions) + +Writes ``data/fitted_primitives.json``. 
+ +Usage:: + + python -m analysis.fit_primitives \\ + --pingpong-intra data/pingpong_intra_*.json \\ + --pingpong-inter data/pingpong_inter_*.json \\ + --compute-gcn data/compute_gcn_*.json \\ + --compute-edge data/compute_edge_*.json \\ + --gather-contiguous data/gather_contiguous_*.json \\ + --gather-clustered data/gather_clustered_*.json \\ + --gather-random data/gather_random_*.json \\ + --output data/fitted_primitives.json +""" + +import argparse +import json +from pathlib import Path + +import numpy as np +from scipy import stats as sp_stats + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def load_json_files(paths: list) -> list: + records = [] + for p in paths: + with open(p) as f: + records.append(json.load(f)) + return records + + +def median_of_trials(trials: list) -> float: + return float(np.median(trials)) + + +def r_squared(y_true: np.ndarray, y_pred: np.ndarray) -> float: + ss_res = np.sum((y_true - y_pred) ** 2) + ss_tot = np.sum((y_true - np.mean(y_true)) ** 2) + return float(1.0 - ss_res / ss_tot) if ss_tot > 0 else 1.0 + + +def linear_fit(x: np.ndarray, y: np.ndarray): + """Fit y = slope * x + intercept via scipy linregress. 
Returns dict.""" + result = sp_stats.linregress(x, y) + y_pred = result.slope * x + result.intercept + r2 = r_squared(y, y_pred) + return { + "slope": float(result.slope), + "intercept": float(result.intercept), + "r_squared": r2, + } + + +# --------------------------------------------------------------------------- +# Network fit: T = t_L + bytes / B +# --------------------------------------------------------------------------- + +def fit_network(records: list) -> dict: + """Fit (t_L, B) from ping-pong records (one mode per call).""" + bytes_arr = [] + time_arr = [] + for rec in records: + for meas in rec["measurements"]: + nbytes = meas["params"]["message_bytes"] + t_med = median_of_trials(meas["trials_seconds"]) + bytes_arr.append(nbytes) + time_arr.append(t_med) + + bytes_arr = np.array(bytes_arr, dtype=float) + time_arr = np.array(time_arr, dtype=float) + + # T = t_L + bytes / B → T = intercept + slope * bytes + # so slope = 1/B, intercept = t_L + fit = linear_fit(bytes_arr, time_arr) + bandwidth = 1.0 / fit["slope"] if fit["slope"] > 0 else float("nan") + latency = fit["intercept"] + return { + "bandwidth_bytes_per_sec": bandwidth, + "latency_seconds": latency, + "r_squared": fit["r_squared"], + "_raw_slope": fit["slope"], + "_raw_intercept": fit["intercept"], + "_num_points": len(bytes_arr), + } + + +# --------------------------------------------------------------------------- +# Compute fit: T = coeff_V * |V| + coeff_E * |E| + intercept +# --------------------------------------------------------------------------- + +def fit_compute(records: list, timing_key: str = "forward_trials_seconds") -> dict: + """Fit compute cost as a function of |V| and |E|. 
+ + Uses multiple linear regression: T = a * |V| + b * |E| + c + """ + V_arr, E_arr, T_arr = [], [], [] + for rec in records: + for meas in rec["measurements"]: + V_arr.append(meas["params"]["num_vertices"]) + E_arr.append(meas["params"]["num_edges"]) + T_arr.append(median_of_trials(meas[timing_key])) + + V_arr = np.array(V_arr, dtype=float) + E_arr = np.array(E_arr, dtype=float) + T_arr = np.array(T_arr, dtype=float) + + # Design matrix: [V, E, 1] + A = np.column_stack([V_arr, E_arr, np.ones_like(V_arr)]) + result, _, _, _ = np.linalg.lstsq(A, T_arr, rcond=None) + coeff_V, coeff_E, intercept = result + T_pred = A @ result + r2 = r_squared(T_arr, T_pred) + + return { + "coeff_V": float(coeff_V), + "coeff_E": float(coeff_E), + "intercept": float(intercept), + "r_squared": r2, + "_num_points": len(T_arr), + } + + +# --------------------------------------------------------------------------- +# Gather fit: T = intercept + bytes / B_gather +# --------------------------------------------------------------------------- + +def fit_gather(records: list, timing_key: str = "gather_trials_seconds") -> dict: + k_arr, T_arr, F_arr = [], [], [] + for rec in records: + F = rec["config"]["feature_dim"] + for meas in rec["measurements"]: + k = meas["params"]["k"] + t_med = median_of_trials(meas[timing_key]) + k_arr.append(k) + T_arr.append(t_med) + F_arr.append(F) + + k_arr = np.array(k_arr, dtype=float) + F_arr = np.array(F_arr, dtype=float) + T_arr = np.array(T_arr, dtype=float) + + bytes_arr = k_arr * F_arr * 4.0 # float32 + fit = linear_fit(bytes_arr, T_arr) + bandwidth = 1.0 / fit["slope"] if fit["slope"] > 0 else float("nan") + return { + "bandwidth_bytes_per_sec": bandwidth, + "intercept_seconds": fit["intercept"], + "r_squared": fit["r_squared"], + "_num_points": len(T_arr), + } + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def parse_args(): + p = 
argparse.ArgumentParser(description="Fit cost-model primitive parameters") + p.add_argument("--pingpong-intra", nargs="+", default=[], metavar="FILE") + p.add_argument("--pingpong-inter", nargs="+", default=[], metavar="FILE") + p.add_argument("--compute-gcn", nargs="+", default=[], metavar="FILE") + p.add_argument("--compute-edge", nargs="+", default=[], metavar="FILE") + p.add_argument("--gather-contiguous", nargs="+", default=[], metavar="FILE") + p.add_argument("--gather-clustered", nargs="+", default=[], metavar="FILE") + p.add_argument("--gather-random", nargs="+", default=[], metavar="FILE") + p.add_argument("--output", type=str, default="data/fitted_primitives.json") + return p.parse_args() + + +def main(): + args = parse_args() + result = {} + + # Network + net = {} + if args.pingpong_intra: + recs = load_json_files(args.pingpong_intra) + net["intra"] = fit_network(recs) + print(f"[network/intra] B={net['intra']['bandwidth_bytes_per_sec']/1e9:.2f} GB/s " + f"t_L={net['intra']['latency_seconds']*1e6:.2f} µs " + f"R²={net['intra']['r_squared']:.4f}") + if args.pingpong_inter: + recs = load_json_files(args.pingpong_inter) + net["inter"] = fit_network(recs) + print(f"[network/inter] B={net['inter']['bandwidth_bytes_per_sec']/1e9:.2f} GB/s " + f"t_L={net['inter']['latency_seconds']*1e6:.2f} µs " + f"R²={net['inter']['r_squared']:.4f}") + result["network"] = net + + # Compute + comp = {} + if args.compute_gcn: + recs = load_json_files(args.compute_gcn) + comp["gcn"] = { + "forward": fit_compute(recs, "forward_trials_seconds"), + "backward": fit_compute(recs, "backward_trials_seconds"), + } + print(f"[compute/gcn] coeff_V={comp['gcn']['forward']['coeff_V']:.3e} " + f"coeff_E={comp['gcn']['forward']['coeff_E']:.3e} " + f"R²={comp['gcn']['forward']['r_squared']:.4f}") + if args.compute_edge: + recs = load_json_files(args.compute_edge) + comp["edge"] = { + "forward": fit_compute(recs, "forward_trials_seconds"), + "backward": fit_compute(recs, 
"backward_trials_seconds"), + } + print(f"[compute/edge] coeff_V={comp['edge']['forward']['coeff_V']:.3e} " + f"coeff_E={comp['edge']['forward']['coeff_E']:.3e} " + f"R²={comp['edge']['forward']['r_squared']:.4f}") + result["compute"] = comp + + # Gather + gath = {} + for dist_name, files_attr in [ + ("contiguous", "gather_contiguous"), + ("clustered", "gather_clustered"), + ("random", "gather_random"), + ]: + files = getattr(args, files_attr.replace("-", "_")) + if files: + recs = load_json_files(files) + gath[dist_name] = { + "gather": fit_gather(recs, "gather_trials_seconds"), + "scatter_add": fit_gather(recs, "scatter_add_trials_seconds"), + } + print(f"[gather/{dist_name}] " + f"B_gather={gath[dist_name]['gather']['bandwidth_bytes_per_sec']/1e9:.2f} GB/s " + f"R²={gath[dist_name]['gather']['r_squared']:.4f}") + result["gather"] = gath + + # Write + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w") as f: + json.dump(result, f, indent=2) + print(f"[fit_primitives] Written to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/benchmarks/__init__.py b/experiments/cost_model_benchmarks/benchmarks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_compute.py b/experiments/cost_model_benchmarks/benchmarks/bench_compute.py new file mode 100644 index 0000000..89f5249 --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/bench_compute.py @@ -0,0 +1,232 @@ +"""Benchmark 1.3 — GNN Layer Compute Primitive. + +Single-GPU benchmark. 
Fits f_comp(|Ṽ|, |Ẽ|) for two message-function +variants: + +* ``gcn`` — GCN-like: φ(h_b) = W h_b (source-only linear transform) +* ``edge`` — Edge-conditioned: φ(h_b, h_a, e_ba) = MLP([h_b, h_a, e_ba]) + with a 2-layer MLP (hidden dim = feature_dim) + +Two sweep modes (controlled by ``--sweep``): + +* ``vertices`` — vary |V| with |E| fixed at ``--fixed-value`` +* ``edges`` — vary |E| with |V| fixed at ``--fixed-value`` + +Usage:: + + python -m benchmarks.bench_compute \\ + --model edge --sweep vertices \\ + --min 1000 --max 100000 --steps 15 \\ + --fixed-value 500000 --feature-dim 128 \\ + --warmup 10 --trials 50 \\ + --output data/compute_edge_vswp.json --seed 42 +""" + +import argparse + +import numpy as np +import torch +import torch.nn as nn + +from benchmarks.common import ( + collect_metadata, + cuda_timed, + seed_everything, + write_result, +) + + +# --------------------------------------------------------------------------- +# Synthetic graph generation +# --------------------------------------------------------------------------- + + +def erdos_renyi_edges( + num_vertices: int, num_edges: int, device: torch.device +) -> torch.Tensor: + """Return an edge index tensor of shape [2, num_edges] (random, with replacement).""" + src = torch.randint(0, num_vertices, (num_edges,), device=device) + dst = torch.randint(0, num_vertices, (num_edges,), device=device) + return torch.stack([src, dst], dim=0) + + +# --------------------------------------------------------------------------- +# GNN layers +# --------------------------------------------------------------------------- + + +class GCNLayer(nn.Module): + """GCN-like: aggregate neighbour source features with a linear transform.""" + + def __init__(self, feature_dim: int): + super().__init__() + self.linear = nn.Linear(feature_dim, feature_dim, bias=False) + + def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor: + # x: [V, F], edge_index: [2, E] + src, dst = edge_index[0], edge_index[1] + 
class EdgeConditionedLayer(nn.Module):
    """Edge-conditioned message function: φ(h_b, h_a, e_ba) = MLP([h_b, h_a, e_ba]).

    Messages are computed per edge from the concatenated source feature,
    destination feature, and edge feature, then sum-aggregated at the
    destination vertex.
    """

    def __init__(self, feature_dim: int):
        super().__init__()
        # Two-layer MLP: 3F -> F -> F, matching the benchmark's "edge" variant.
        self.mlp = nn.Sequential(
            nn.Linear(3 * feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, feature_dim),
        )

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor
    ) -> torch.Tensor:
        # x: [V, F], edge_index: [2, E], edge_attr: [E, F]
        senders = edge_index[0]
        receivers = edge_index[1]
        # One message per edge from the (src, dst, edge) feature triple.
        triples = torch.cat((x[senders], x[receivers], edge_attr), dim=-1)  # [E, 3F]
        messages = self.mlp(triples)  # [E, F]
        # Sum incoming messages at each destination vertex.
        aggregated = torch.zeros_like(x)
        aggregated.index_add_(0, receivers, messages)
        return aggregated
= parse_args() + seed_everything(args.seed) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + F = args.feature_dim + + # Build model + if args.model == "gcn": + model = GCNLayer(F).to(device) + else: + model = EdgeConditionedLayer(F).to(device) + + # Sweep points + sweep_vals = np.unique( + np.round( + np.logspace( + np.log10(args.sweep_min), + np.log10(args.sweep_max), + num=args.steps, + ) + ).astype(int) + ).tolist() + + measurements = [] + for val in sweep_vals: + if args.sweep == "vertices": + num_v, num_e = val, args.fixed_value + else: + num_v, num_e = args.fixed_value, val + + # Synthetic data + x = torch.randn(num_v, F, device=device, requires_grad=True) + edge_index = erdos_renyi_edges(num_v, num_e, device) + edge_attr = ( + torch.randn(num_e, F, device=device) if args.model == "edge" else None + ) + + # Forward timing + def fwd(): + if args.model == "gcn": + model(x, edge_index) + else: + model(x, edge_index, edge_attr) + + fwd_times = cuda_timed(fwd, warmup=args.warmup, trials=args.trials) + + # Backward timing (run fwd first to get a graph) + if args.model == "gcn": + out = model(x, edge_index) + else: + out = model(x, edge_index, edge_attr) + loss_ref = out.sum() + + def bwd(): + if x.grad is not None: + x.grad.zero_() + if args.model == "gcn": + out_inner = model(x, edge_index) + else: + out_inner = model(x, edge_index, edge_attr) + out_inner.sum().backward() + + bwd_times = cuda_timed(bwd, warmup=args.warmup, trials=args.trials) + + measurements.append( + { + "params": { + "num_vertices": num_v, + "num_edges": num_e, + "sweep_var": args.sweep, + "sweep_value": val, + "model": args.model, + "feature_dim": F, + }, + "forward_trials_seconds": fwd_times, + "backward_trials_seconds": bwd_times, + } + ) + med_fwd = sorted(fwd_times)[len(fwd_times) // 2] + med_bwd = sorted(bwd_times)[len(bwd_times) // 2] + print( + f"[compute/{args.model}] |V|={num_v:>8} |E|={num_e:>9} " + f"fwd {1e3*med_fwd:.2f} ms bwd {1e3*med_bwd:.2f} ms" + ) + 
+ payload = { + "benchmark": "compute", + "metadata": collect_metadata(), + "config": { + "model": args.model, + "sweep": args.sweep, + "sweep_min": args.sweep_min, + "sweep_max": args.sweep_max, + "steps": args.steps, + "fixed_value": args.fixed_value, + "feature_dim": F, + "warmup": args.warmup, + "trials": args.trials, + "seed": args.seed, + }, + "measurements": measurements, + } + write_result(args.output, payload) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_concurrency.py b/experiments/cost_model_benchmarks/benchmarks/bench_concurrency.py new file mode 100644 index 0000000..8d28cfd --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/bench_concurrency.py @@ -0,0 +1,227 @@ +"""Benchmark 1.2 — Intra/Inter Concurrency Check. + +Verifies that T_both / max(T_intra, T_inter) ≈ 1, i.e. that NVLink and +InfiniBand transfers overlap when issued simultaneously. + +Requires exactly 4 ranks across 2 nodes: + Node A: rank 0 (A0), rank 1 (A1) + Node B: rank 2 (B0), rank 3 (B1) + +Three conditions at a fixed message size: + 1. intra-only — A0↔A1 and B0↔B1 (no cross-node traffic) + 2. inter-only — A0↔B0 and A1↔B1 (no intra-node traffic) + 3. concurrent — all four exchanges in the same window (separate streams) + +Each rank logs its own wall time per trial. Rank 0 collects and writes JSON. 
def timed_concurrent(rank: int,
                     intra_send: torch.Tensor, intra_recv: torch.Tensor, intra_peer: int,
                     inter_send: torch.Tensor, inter_recv: torch.Tensor, inter_peer: int,
                     intra_stream: torch.cuda.Stream,
                     inter_stream: torch.cuda.Stream,
                     warmup: int,
                     trials: int) -> list:
    """Time intra- and inter-node exchanges issued concurrently.

    Both exchanges are launched back-to-back on their own CUDA streams so the
    NVLink and InfiniBand transfers can overlap inside one timed window.
    Returns per-trial window times in seconds for this rank.
    """
    def issue_both() -> None:
        # Intra-node pair first, then inter-node pair, on separate streams.
        _exchange(intra_send, intra_recv, intra_peer, intra_stream)
        _exchange(inter_send, inter_recv, inter_peer, inter_stream)

    # Warm up both links before measuring.
    for _ in range(warmup):
        issue_both()
    torch.cuda.synchronize()
    dist.barrier()

    window_seconds = []
    for _ in range(trials):
        dist.barrier()  # align ranks so windows start together
        begin = torch.cuda.Event(enable_timing=True)
        finish = torch.cuda.Event(enable_timing=True)
        begin.record()
        issue_both()
        finish.record()
        torch.cuda.synchronize()
        # elapsed_time is in milliseconds; convert to seconds.
        window_seconds.append(begin.elapsed_time(finish) / 1_000.0)
    return window_seconds
+ ) + + seed_everything(args.seed) + device = torch.device(f"cuda:{local_rank}") + + num_elems = max(1, args.message_bytes // 4) + send_buf = torch.randn(num_elems, dtype=torch.float32, device=device) + recv_buf = torch.zeros(num_elems, dtype=torch.float32, device=device) + + # Intra-node peers: 0↔1, 2↔3 + # Inter-node peers: 0↔2, 1↔3 + intra_peer = {0: 1, 1: 0, 2: 3, 3: 2}[rank] + inter_peer = {0: 2, 1: 3, 2: 0, 3: 1}[rank] + + intra_stream = torch.cuda.Stream(device=device) + inter_stream = torch.cuda.Stream(device=device) + + intra_send = send_buf.clone() + intra_recv = torch.zeros_like(recv_buf) + inter_send = send_buf.clone() + inter_recv = torch.zeros_like(recv_buf) + + # --- Condition 1: intra-only --- + times_intra = timed_window( + rank, intra_send, intra_recv, intra_peer, intra_stream, + args.warmup, args.trials + ) + dist.barrier() + + # --- Condition 2: inter-only --- + times_inter = timed_window( + rank, inter_send, inter_recv, inter_peer, inter_stream, + args.warmup, args.trials + ) + dist.barrier() + + # --- Condition 3: concurrent --- + times_concurrent = timed_concurrent( + rank, + intra_send, intra_recv, intra_peer, + inter_send, inter_recv, inter_peer, + intra_stream, inter_stream, + args.warmup, args.trials, + ) + dist.barrier() + + # Gather per-rank times to rank 0 + def gather_times(times_local): + obj = [None] * world_size + dist.all_gather_object(obj, times_local) + return obj + + intra_all = gather_times(times_intra) + inter_all = gather_times(times_inter) + conc_all = gather_times(times_concurrent) + + if rank == 0: + measurements = [ + { + "params": {"condition": "intra_only", "message_bytes": num_elems * 4}, + "per_rank_trials_seconds": intra_all, + }, + { + "params": {"condition": "inter_only", "message_bytes": num_elems * 4}, + "per_rank_trials_seconds": inter_all, + }, + { + "params": {"condition": "concurrent", "message_bytes": num_elems * 4}, + "per_rank_trials_seconds": conc_all, + }, + ] + payload = { + "benchmark": "concurrency", 
+ "metadata": collect_metadata(), + "config": { + "message_bytes": args.message_bytes, + "warmup": args.warmup, + "trials": args.trials, + "world_size": world_size, + "seed": args.seed, + "rank_layout": "rank 0,1 on node A; rank 2,3 on node B", + }, + "measurements": measurements, + } + write_result(args.output, payload) + print( + f"[concurrency] intra median = " + f"{1e3*sorted(intra_all[0])[len(intra_all[0])//2]:.2f} ms | " + f"inter median = " + f"{1e3*sorted(inter_all[0])[len(inter_all[0])//2]:.2f} ms | " + f"concurrent median = " + f"{1e3*sorted(conc_all[0])[len(conc_all[0])//2]:.2f} ms" + ) + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_end_to_end.py b/experiments/cost_model_benchmarks/benchmarks/bench_end_to_end.py new file mode 100644 index 0000000..7da40eb --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/bench_end_to_end.py @@ -0,0 +1,504 @@ +"""Benchmark 2.1 — End-to-End Halo Exchange. + +Measures full GNN layer wall time (forward + backward) across a sweep of +configurations on the full multi-node setup. Intended to be run as a SLURM +array job with one invocation per (K, F, graph) combination. + +This module contains a self-contained minimal halo-exchange implementation +(no dependency on the DGraph production library) so the benchmark remains +isolated and portable. 
def gen_sbm(num_vertices: int, avg_degree: float, inter_density: float,
            rng: np.random.Generator) -> np.ndarray:
    """Return an [E, 2] (src, dst) edge array for a Stochastic Block Model graph.

    Vertices are divided into equal-size contiguous blocks, one per rank
    (falls back to 4 blocks when no process group is initialized — the actual
    partitioning is a separate step). *inter_density* is the fraction of
    edges forced to cross block boundaries; the rest stay within one block.
    """
    world_size = dist.get_world_size() if dist.is_initialized() else 4
    block_size = num_vertices // world_size
    target_edges = int(num_vertices * avg_degree)

    intra_total = int(target_edges * (1.0 - inter_density))
    inter_total = target_edges - intra_total
    per_block = intra_total // world_size

    chunks = []
    # Intra-block edges: both endpoints sampled from the same block.
    for block in range(world_size):
        lo = block * block_size
        hi = lo + block_size
        sources = rng.integers(lo, hi, size=per_block)
        targets = rng.integers(lo, hi, size=per_block)
        chunks.append(np.stack([sources, targets], axis=1))

    # Inter-block edges.
    sources = rng.integers(0, num_vertices, size=inter_total)
    # NOTE: this draw is overwritten below; it is kept so the RNG stream
    # consumed here matches the original sampling sequence exactly.
    targets = rng.integers(0, num_vertices, size=inter_total)
    # Force cross-block: rotate each destination into a block different from
    # its source's, then pick a vertex inside that block.
    offset = rng.integers(0, world_size - 1, size=inter_total)
    target_block = (sources // block_size + 1 + offset) % world_size
    targets = target_block * block_size + rng.integers(0, block_size, size=inter_total)
    targets = np.clip(targets, 0, num_vertices - 1)
    chunks.append(np.stack([sources, targets], axis=1))

    return np.concatenate(chunks, axis=0)
def build_local_comm_pattern(edges: np.ndarray, assignment: np.ndarray,
                             rank: int, world_size: int):
    """Compute this rank's halo-exchange communication pattern.

    Requires an initialized process group: the send pattern is learned by
    all-gathering every rank's halo vertex list (via ``all_gather_object``).

    Returns a dict with:
        local_vertices    — global IDs of vertices owned by this rank
        n_local / n_halo  — owned / halo vertex counts
        edge_index        — [2, E_local] tensor in local IDs; 0..n_local-1
                            are owned vertices, n_local.. are halo vertices
        send_counts / recv_counts — per-rank vertex counts to send / receive
        send_idx_by_rank  — local indices of owned vertices each peer needs
        halo_order        — global IDs of halo vertices, in receive order
        intra_halo_size / inter_halo_size — halo split by same / other node
        ranks_per_node    — ranks co-located per node (from env vars)
    """
    owned = np.where(assignment == rank)[0]
    n_local = len(owned)
    global_to_local = {int(v): i for i, v in enumerate(owned)}

    # Keep only edges whose destination lives on this rank.
    incoming = edges[np.isin(edges[:, 1], owned)]

    # Halo vertices: sources of incoming edges owned by other ranks.
    remote_src_mask = ~np.isin(incoming[:, 0], owned)
    halo_global = np.unique(incoming[remote_src_mask, 0])
    halo_owner = assignment[halo_global]

    # Group halo vertices by owning rank so receive buffers concatenate in
    # rank order.
    recv_by_rank = [halo_global[halo_owner == r] for r in range(world_size)]
    halo_order = np.concatenate(recv_by_rank).astype(np.int64)

    # Combined global -> local map: owned vertices first, halo after.
    lookup = dict(global_to_local)
    lookup.update({int(v): n_local + i for i, v in enumerate(halo_order)})

    recv_counts = [len(block) for block in recv_by_rank]

    # Learn the send pattern: every rank publishes its halo list; each rank
    # then picks out the vertices it owns from every peer's list.
    all_halo_lists = [None] * world_size
    dist.all_gather_object(all_halo_lists, halo_order.tolist())

    send_idx_by_rank = []
    for peer in range(world_size):
        wanted = np.array(all_halo_lists[peer], dtype=np.int64)
        if wanted.size:
            mine = wanted[assignment[wanted] == rank]
        else:
            mine = np.array([], dtype=np.int64)
        send_idx_by_rank.append(
            np.array([global_to_local[int(v)] for v in mine], dtype=np.int64)
        )
    send_counts = [len(block) for block in send_idx_by_rank]

    # Remap edges into local IDs (drop any edge touching an unknown vertex;
    # with a consistent assignment this drops nothing).
    known = np.fromiter(
        ((int(a) in lookup) and (int(b) in lookup) for a, b in incoming),
        dtype=bool, count=len(incoming),
    )
    kept = incoming[known]
    if len(kept) > 0:
        edge_index = torch.tensor(
            np.stack([
                [lookup[int(a)] for a in kept[:, 0]],
                [lookup[int(b)] for b in kept[:, 1]],
            ], axis=0),
            dtype=torch.long,
        )
    else:
        edge_index = torch.zeros((2, 0), dtype=torch.long)

    # Split the halo by node locality; a node is a contiguous run of
    # ranks_per_node ranks (derived from launcher environment variables).
    ranks_per_node = int(os.environ.get(
        "LOCAL_WORLD_SIZE", os.environ.get("SLURM_NTASKS_PER_NODE", "4")))
    my_node = rank // ranks_per_node
    intra_halo_size = sum(
        len(block) for r, block in enumerate(recv_by_rank)
        if r // ranks_per_node == my_node)
    inter_halo_size = sum(
        len(block) for r, block in enumerate(recv_by_rank)
        if r // ranks_per_node != my_node)

    return {
        "local_vertices": owned,
        "n_local": n_local,
        "n_halo": len(halo_order),
        "edge_index": edge_index,
        "send_counts": send_counts,
        "recv_counts": recv_counts,
        "send_idx_by_rank": send_idx_by_rank,
        "halo_order": halo_order,
        "intra_halo_size": intra_halo_size,
        "inter_halo_size": inter_halo_size,
        "ranks_per_node": ranks_per_node,
    }
class GCNLayer(nn.Module):
    """GCN-like layer: linearly transform source features, sum at destinations.

    Mirrors the ``gcn`` variant in bench_compute.py so end-to-end and
    compute-only timings measure the same message function.
    """

    def __init__(self, feature_dim: int):
        super().__init__()
        self.linear = nn.Linear(feature_dim, feature_dim, bias=False)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        senders, receivers = edge_index[0], edge_index[1]
        # edge_index is already restricted to edges with local destinations,
        # so every receiver indexes a valid row of the output buffer.
        messages = self.linear(x[senders])  # [E, F]
        aggregated = torch.zeros_like(x)
        aggregated.index_add_(0, receivers, messages)
        return aggregated
=========================================================================== + +def parse_args(): + p = argparse.ArgumentParser(description="End-to-end halo exchange benchmark") + p.add_argument("--graph", choices=["erdos_renyi", "sbm"], default="erdos_renyi") + p.add_argument("--num-vertices", type=int, default=100_000) + p.add_argument("--avg-degree", type=float, default=20.0) + p.add_argument("--sbm-inter-density", type=float, default=0.1, + help="Fraction of inter-block edges for SBM graphs") + p.add_argument("--feature-dim", type=int, default=128) + p.add_argument("--model", choices=["gcn", "edge"], default="gcn") + p.add_argument("--partitioner", choices=["random", "balanced", "metis"], + default="balanced") + p.add_argument("--warmup", type=int, default=10) + p.add_argument("--trials", type=int, default=50) + p.add_argument("--output", type=str, required=True) + p.add_argument("--seed", type=int, default=42) + return p.parse_args() + + +def main(): + args = parse_args() + rank, world_size, local_rank = setup_distributed() + seed_everything(args.seed + rank) # per-rank seed for graph generation + rng = np.random.default_rng(args.seed) # shared seed for graph topology + device = torch.device(f"cuda:{local_rank}") + F = args.feature_dim + + # --- Generate graph on all ranks (same seed → identical graph) --- + if args.graph == "erdos_renyi": + edges = gen_erdos_renyi(args.num_vertices, args.avg_degree, rng) + else: + edges = gen_sbm(args.num_vertices, args.avg_degree, + args.sbm_inter_density, rng) + + # --- Partition --- + rng_part = np.random.default_rng(args.seed + 1) + if args.partitioner == "random": + assignment = partition_random(args.num_vertices, world_size, rng_part) + elif args.partitioner == "balanced": + assignment = partition_balanced(args.num_vertices, world_size) + else: + assignment = partition_metis(args.num_vertices, world_size, edges) + + # --- Build local comm pattern --- + pattern = build_local_comm_pattern(edges, assignment, rank, 
world_size) + n_local = pattern["n_local"] + n_halo = pattern["n_halo"] + edge_index = pattern["edge_index"].to(device) + + send_counts = pattern["send_counts"] + recv_counts = pattern["recv_counts"] + send_idx_flat = torch.cat([ + torch.tensor(s, dtype=torch.long) for s in pattern["send_idx_by_rank"] + ]).to(device) if sum(send_counts) > 0 else torch.zeros(0, dtype=torch.long, device=device) + + # --- Model --- + if args.model == "gcn": + layer = GCNLayer(F).to(device) + else: + layer = EdgeConditionedLayer(F).to(device) + layer.train() + + # --- Synthetic local node features --- + x_local = torch.randn(n_local, F, device=device, requires_grad=True) + edge_attr = torch.randn(edge_index.shape[1], F, device=device) \ + if args.model == "edge" else None + + # --- Timed forward + backward --- + def one_layer(): + # Forward halo exchange + recv_buf = MinimalHaloExchange.apply( + x_local, send_idx_flat, send_counts, recv_counts, world_size + ) + # Augment: local + halo + x_aug = torch.cat([x_local, recv_buf], dim=0) + # Message passing + if args.model == "gcn": + out = layer(x_aug, edge_index) + else: + out = layer(x_aug, edge_index, edge_attr) + # Backward + loss = out.sum() + loss.backward() + if x_local.grad is not None: + x_local.grad.zero_() + + # Barrier before timing + dist.barrier() + times_local = cuda_timed(one_layer, warmup=args.warmup, trials=args.trials) + dist.barrier() + + # Gather per-rank times and stats to rank 0 + stats_local = { + "rank": rank, + "n_local": n_local, + "n_halo": n_halo, + "intra_halo_size": pattern["intra_halo_size"], + "inter_halo_size": pattern["inter_halo_size"], + "c_intra_bytes": pattern["intra_halo_size"] * F * 4, + "c_inter_bytes": pattern["inter_halo_size"] * F * 4, + "send_total": sum(send_counts), + "recv_total": sum(recv_counts), + "trials_seconds": times_local, + } + + all_stats = [None] * world_size + dist.all_gather_object(all_stats, stats_local) + + if rank == 0: + med = sorted(times_local)[len(times_local) // 2] + 
print( + f"[e2e] K={world_size} F={F} {args.graph}/{args.partitioner}/{args.model} " + f"n_local={n_local} n_halo={n_halo} " + f"median {1e3*med:.2f} ms" + ) + payload = { + "benchmark": "end_to_end", + "metadata": collect_metadata(), + "config": { + "graph": args.graph, + "num_vertices": args.num_vertices, + "avg_degree": args.avg_degree, + "sbm_inter_density": args.sbm_inter_density, + "feature_dim": F, + "model": args.model, + "partitioner": args.partitioner, + "world_size": world_size, + "ranks_per_node": pattern["ranks_per_node"], + "warmup": args.warmup, + "trials": args.trials, + "seed": args.seed, + }, + "measurements": [ + { + "params": { + "world_size": world_size, + "feature_dim": F, + "graph": args.graph, + "partitioner": args.partitioner, + "model": args.model, + }, + "rank0_trials_seconds": times_local, + "per_rank_stats": all_stats, + } + ], + } + write_result(args.output, payload) + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_gather.py b/experiments/cost_model_benchmarks/benchmarks/bench_gather.py new file mode 100644 index 0000000..d62c71b --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/bench_gather.py @@ -0,0 +1,181 @@ +"""Benchmark 1.4 — Buffer-Copy / Gather-Scatter Bandwidth. + +Single-GPU benchmark. Measures the effective bandwidth of ``x[idx]`` (gather) +and the corresponding ``scatter_add_`` (backward) for three index distributions: + +* ``contiguous`` — contiguous block starting at a random offset +* ``clustered`` — *c* cluster centres, each with a contiguous block of size + ``--cluster-size``; simulates a well-partitioned graph halo +* ``random`` — uniformly random indices (worst-case for cache) + +Sweeps *k* (number of gathered rows) from ``--min-k`` to ``--max-k``. 
def clustered_idx(N: int, k: int, cluster_size: int,
                  device: torch.device, rng: np.random.Generator) -> torch.Tensor:
    """Draw cluster centres, then take a contiguous block around each.

    Simulates a well-partitioned graph halo: indices arrive in runs of
    ``cluster_size`` contiguous rows anchored at random centres.

    Args:
        N: number of rows in the source tensor (assumed >= cluster_size).
        k: number of indices to return.
        cluster_size: contiguous rows taken around each centre.
        device: device for the returned index tensor.
        rng: NumPy generator used to draw the cluster centres.

    Returns:
        1-D LongTensor of exactly ``k`` indices in [0, N).

    Fix: the cluster count now rounds *up*. The previous
    ``k // cluster_size`` floor returned fewer than ``k`` indices whenever
    ``k`` was not a multiple of ``cluster_size`` (e.g. k=100,
    cluster_size=64 yielded only 64 rows), silently shrinking the gather.
    """
    # Ceil division: enough clusters to cover at least k rows.
    num_clusters = max(1, -(-k // cluster_size))
    centres = rng.integers(0, N, size=num_clusters)
    idx_parts = []
    for c in centres:
        # Clamp so the block [start, start + cluster_size) stays inside [0, N).
        start = int(np.clip(c, 0, max(0, N - cluster_size)))
        idx_parts.append(torch.arange(start, start + cluster_size, device=device))
    # Concatenate all blocks and trim the overshoot from the last cluster.
    return torch.cat(idx_parts)[:k]
help="Total number of rows in the source tensor x") + p.add_argument("--feature-dim", type=int, default=128) + p.add_argument("--cluster-size", type=int, default=64, + help="Rows per cluster (only used with --distribution clustered)") + p.add_argument("--warmup", type=int, default=10) + p.add_argument("--trials", type=int, default=50) + p.add_argument("--output", type=str, required=True) + p.add_argument("--seed", type=int, default=42) + return p.parse_args() + + +def main(): + args = parse_args() + seed_everything(args.seed) + rng = np.random.default_rng(args.seed) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + F = args.feature_dim + + # Pre-allocate the source tensor and gradient buffer once + x = torch.randn(args.N, F, device=device) + grad_y_template = torch.ones(1, F, device=device) # resized per k + + # Sweep k values + k_values = np.unique( + np.round( + np.logspace( + np.log10(args.min_k), + np.log10(args.max_k), + num=args.steps, + ) + ).astype(int) + ).tolist() + k_values = [min(int(k), args.N) for k in k_values] + + measurements = [] + for k in k_values: + # Build index tensor + if args.distribution == "contiguous": + idx = contiguous_idx(args.N, k, device) + elif args.distribution == "clustered": + idx = clustered_idx(args.N, k, args.cluster_size, device, rng) + else: + idx = random_idx(args.N, k, device) + + # Expand idx for scatter_add_: shape [k, F] + idx_expanded = idx.unsqueeze(1).expand(-1, F) + grad_y = torch.ones(k, F, device=device) + grad_x = torch.zeros_like(x) + + # --- Forward gather --- + def gather_fn(): + _ = x[idx] + + gather_times = cuda_timed(gather_fn, warmup=args.warmup, trials=args.trials) + + # --- Backward scatter-add --- + def scatter_fn(): + grad_x.zero_() + grad_x.scatter_add_(0, idx_expanded, grad_y) + + scatter_times = cuda_timed(scatter_fn, warmup=args.warmup, trials=args.trials) + + measurements.append({ + "params": { + "k": k, + "N": args.N, + "feature_dim": F, + "distribution": 
args.distribution, + "cluster_size": args.cluster_size if args.distribution == "clustered" else None, + }, + "gather_trials_seconds": gather_times, + "scatter_add_trials_seconds": scatter_times, + }) + + med_g = sorted(gather_times)[len(gather_times) // 2] + med_s = sorted(scatter_times)[len(scatter_times) // 2] + bytes_moved = k * F * 4 + bw_g = bytes_moved / med_g / 1e9 + bw_s = bytes_moved / med_s / 1e9 + print( + f"[gather/{args.distribution}] k={k:>9} " + f"gather {1e3*med_g:.2f} ms ({bw_g:.1f} GB/s) " + f"scatter {1e3*med_s:.2f} ms ({bw_s:.1f} GB/s)" + ) + + payload = { + "benchmark": "gather", + "metadata": collect_metadata(), + "config": { + "distribution": args.distribution, + "min_k": args.min_k, + "max_k": args.max_k, + "steps": args.steps, + "N": args.N, + "feature_dim": F, + "cluster_size": args.cluster_size, + "warmup": args.warmup, + "trials": args.trials, + "seed": args.seed, + }, + "measurements": measurements, + } + write_result(args.output, payload) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_pingpong.py b/experiments/cost_model_benchmarks/benchmarks/bench_pingpong.py new file mode 100644 index 0000000..c844311 --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/bench_pingpong.py @@ -0,0 +1,161 @@ +"""Benchmark 1.1 — Network Bandwidth / Latency (Ping-Pong). + +Measures one-way transfer time across a sweep of message sizes using a +two-rank ping-pong pattern. Run once with both ranks on the same node +(intra-node, NVLink) and once with ranks on different nodes (inter-node, +InfiniBand). The SLURM script controls placement; this script only records +a --mode label. 
"""
Usage (via torchrun / srun)::

    torchrun --nnodes 1 --nproc_per_node 2 \\
        -m benchmarks.bench_pingpong \\
        --mode intra --min-bytes 64 --max-bytes 67108864 --steps 21 \\
        --warmup 20 --trials 100 --output data/pingpong_intra.json --seed 42
"""

import argparse
import os

import torch
import torch.distributed as dist

from benchmarks.common import (
    collect_metadata,
    seed_everything,
    setup_distributed,
    write_result,
)


# ---------------------------------------------------------------------------
# Ping-pong timing
# ---------------------------------------------------------------------------

def pingpong_timed(rank: int, tensor: torch.Tensor, warmup: int, trials: int) -> list:
    """Bounce *tensor* between rank 0 and rank 1.

    Rank 0 returns per-trial *one-way* transfer times in seconds (measured
    round-trip halved); rank 1 participates in the exchanges and returns an
    empty list.
    """
    def initiate_roundtrip():
        # Rank 0 side: send then wait for the echo.
        dist.send(tensor, dst=1)
        dist.recv(tensor, src=1)

    def echo_roundtrip():
        # Rank 1 side: wait for the message, then echo it back.
        dist.recv(tensor, src=0)
        dist.send(tensor, dst=0)

    # Untimed warmup exchanges.
    for _ in range(warmup):
        if rank == 0:
            initiate_roundtrip()
        else:
            echo_roundtrip()
    torch.cuda.synchronize()
    dist.barrier()

    one_way = []
    for _ in range(trials):
        dist.barrier()
        if rank != 0:
            echo_roundtrip()
            continue
        tic = torch.cuda.Event(enable_timing=True)
        toc = torch.cuda.Event(enable_timing=True)
        tic.record()
        initiate_roundtrip()
        toc.record()
        torch.cuda.synchronize()
        # elapsed_time is milliseconds round-trip; halve and convert to s.
        one_way.append(tic.elapsed_time(toc) / 2.0 / 1_000.0)

    return one_way


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def parse_args():
    """Parse CLI arguments for the ping-pong sweep."""
    p = argparse.ArgumentParser(description="Ping-pong bandwidth/latency benchmark")
    p.add_argument("--min-bytes", type=int, default=64)
    p.add_argument("--max-bytes", type=int, default=67_108_864)  # 64 MiB
    p.add_argument("--steps", type=int, default=21,
                   help="Number of logarithmically-spaced message sizes")
    p.add_argument("--warmup", type=int, default=20)
    p.add_argument("--trials", type=int, default=100)
    p.add_argument("--mode", choices=["intra", "inter"], default="inter",
                   help="Label only — actual placement is controlled by SLURM")
    p.add_argument("--output", type=str, required=True)
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args()


def main():
    """Run the two-rank ping-pong sweep and (on rank 0) write results."""
    args = parse_args()
    rank, world_size, local_rank = setup_distributed()

    if world_size != 2:
        raise ValueError(f"bench_pingpong requires exactly 2 ranks, got {world_size}")

    seed_everything(args.seed)
    device = torch.device(f"cuda:{local_rank}")

    # Base-2 log-spaced message sizes (powers-of-2 friendly), deduplicated.
    import numpy as np
    sizes = np.unique(
        np.round(
            np.logspace(
                np.log2(args.min_bytes),
                np.log2(args.max_bytes),
                num=args.steps,
                base=2,
            )
        ).astype(int)
    ).tolist()

    measurements = []
    for nbytes in sizes:
        # Round the requested byte count down to whole float32 elements.
        num_elems = max(1, nbytes // 4)
        payload_tensor = torch.zeros(num_elems, dtype=torch.float32, device=device)

        times = pingpong_timed(rank, payload_tensor, args.warmup, args.trials)

        if rank == 0:
            measurements.append({
                "params": {
                    "message_bytes": num_elems * 4,
                    "num_elements": num_elems,
                    "mode": args.mode,
                },
                "trials_seconds": times,
            })
            print(
                f"[pingpong] {num_elems * 4:>10} bytes | "
                f"median {1e3 * float(sorted(times)[len(times)//2]):.3f} ms"
            )

        dist.barrier()

    if rank == 0:
        payload = {
            "benchmark": "pingpong",
            "metadata": collect_metadata(),
            "config": {
                "min_bytes": args.min_bytes,
                "max_bytes": args.max_bytes,
                "steps": args.steps,
                "warmup": args.warmup,
                "trials": args.trials,
                "mode": args.mode,
                "world_size": world_size,
                "seed": args.seed,
            },
            "measurements": measurements,
        }
        write_result(args.output, payload)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
# diff --git a/experiments/cost_model_benchmarks/benchmarks/common.py
#          b/experiments/cost_model_benchmarks/benchmarks/common.py
# new file mode 100644, index 0000000..27e8069

"""Shared utilities for cost-model benchmarks: timing, logging, metadata."""

import json
import os
import random
import socket
import subprocess
import time
from pathlib import Path
from typing import Callable

import numpy as np
import torch
import torch.distributed as dist


# ---------------------------------------------------------------------------
# Timing
# ---------------------------------------------------------------------------


def cuda_timed(fn: Callable, warmup: int = 10, trials: int = 50) -> list:
    """Time *fn* with CUDA events; return per-trial wall times in seconds.

    *fn* is invoked with no arguments — callers capture any needed state via
    closure. The *warmup* invocations are discarded.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    samples = []
    for _ in range(trials):
        tic = torch.cuda.Event(enable_timing=True)
        toc = torch.cuda.Event(enable_timing=True)
        tic.record()
        fn()
        toc.record()
        torch.cuda.synchronize()
        # elapsed_time reports milliseconds; convert to seconds.
        samples.append(tic.elapsed_time(toc) / 1_000.0)
    return samples


# ---------------------------------------------------------------------------
# Metadata collection
# ---------------------------------------------------------------------------


def _gpu_info(index: int) -> dict:
    """One reproducibility record for CUDA device *index*."""
    props = torch.cuda.get_device_properties(index)
    info = {
        "index": index,
        "name": props.name,
        "compute_capability": f"{props.major}.{props.minor}",
        "total_memory_bytes": props.total_memory,
    }
    # UUID is only exposed by newer PyTorch builds.
    if hasattr(props, "uuid"):
        info["uuid"] = str(props.uuid)
    return info


def collect_metadata() -> dict:
    """Return a dict of reproducibility metadata for the current run."""
    meta: dict = {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "hostname": socket.gethostname(),
    }

    # GPU info
    if torch.cuda.is_available():
        meta["gpus"] = [_gpu_info(i) for i in range(torch.cuda.device_count())]
        meta["cuda_version"] = torch.version.cuda
    else:
        meta["gpus"] = []
        meta["cuda_version"] = None

    meta["pytorch_version"] = torch.__version__

    # NCCL version (tuple -> dotted string); tolerate builds without NCCL.
    try:
        meta["nccl_version"] = ".".join(str(x) for x in torch.cuda.nccl.version())
    except Exception:
        meta["nccl_version"] = "unknown"

    # SLURM environment variables (None when not running under SLURM).
    meta["slurm"] = {
        key: os.environ.get(key)
        for key in (
            "SLURM_JOB_ID",
            "SLURM_NODELIST",
            "SLURM_NNODES",
            "SLURM_NTASKS",
            "SLURM_PROCID",
            "SLURM_LOCALID",
            "SLURM_ARRAY_JOB_ID",
            "SLURM_ARRAY_TASK_ID",
        )
    }

    # Git commit hash of the benchmark code.
    try:
        proc = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True,
            text=True,
            check=True,
        )
        meta["git_commit"] = proc.stdout.strip()
    except Exception:
        meta["git_commit"] = "unknown"

    return meta


# ---------------------------------------------------------------------------
# JSON output
# ---------------------------------------------------------------------------


def write_result(path: str, payload: dict) -> None:
    """Write *payload* as a JSON file at *path*, creating parents as needed.

    Expected schema::

        {
            "benchmark": "",
            "metadata": { ... },
            "config": { ... },
            "measurements": [
                {"params": {...}, "trials_seconds": [t1, t2, ...]},
                ...
            ]
        }
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(json.dumps(payload, indent=2))
    print(f"[write_result] Saved {target} ({target.stat().st_size} bytes)")


# ---------------------------------------------------------------------------
# Distributed setup
# ---------------------------------------------------------------------------


def setup_distributed() -> tuple[int, int, int]:
    """Initialize torch.distributed with the env-var init method (NCCL).

    Expects MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK, and LOCAL_RANK to be
    set in the environment (standard for torchrun / SLURM + srun).

    Returns:
        (rank, world_size, local_rank)
    """
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl", init_method="env://")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank


# ---------------------------------------------------------------------------
# Seeding
# ---------------------------------------------------------------------------


def seed_everything(seed: int) -> None:
    """Set Python, NumPy, and PyTorch (CPU + CUDA) random seeds."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

# --- patch continues: deleted binary file (residue preserved as comments) ---
# diff --git a/experiments/cost_model_benchmarks/figures/compute.pdf
#          b/experiments/cost_model_benchmarks/figures/compute.pdf
# deleted file mode 100644
# index f60f93169802441205acbaa7ec6f0d27bcad62cb..0000000000000000000000000000000000000000
# GIT binary patch
# literal 0
# HcmV?d00001
# literal 19308
# zcmdtK1yof{^f#;$%9T(`QNRlb2okrrmrF@^ch@DAZc!;gK)OLfDFp-pDN#VWQ)vMy
# zB?J*fN|g5;)aNnye~;g{zV)tmz1%g-oP8$t?7h$I-^`pj%*x`DoNz7_ggJi<`tT_P
# z4uwJOO)oaEO
# zMeb+d)>bmTY>sn*V!mBUc(_QZxtO?^Ls8!d$|f!@=1z7{9`F*vt!8CnW^HE)MSs8R
# zWRFubcY*2w&5BC`teAVaK)GdX0Tx8R$70`O8K}We?7;u*0LVAUy^}e>-8c5RRn494
# zU7c`1e?b0$_7%*{tW8AiJ%AQrz>fz8L7KVl*;BYPs5{}^k+Ts=idI?nJ
# z48?xqRm9HD9=L*o>mbK598&bx+sVVd6EXpG4!Kz7ugAYvAy$^PjHX+{G};5BgGh|
# zIJT+RHp+xq%~z=!>K{YxnMIS-rHPEUbpxmu&&|d6T0M*ofj7mvCbnIBqrCH_VfWI<
# z&5s{@NVNR5)0{qXzr=LB)(wE|aW8J(_U^pn)C
z@mg7GlQaXozBCV6amcu?2f};=X(}WXR;Wk$Z6m4|eb#Gslubu^J6`u3j*%(0{yS^O=-+W#nx1AqNcD;9)wpw$`(UKIiZKG;;bRyJvq9#Pb z8%>JyLqC68e3RaNNm;RAJF)gHrRrICUw6V96OJhR?Qk)rNyk3cWwzP%cl3|@wm$HO zMPaulN9VP+qUM+?wrpyS1kxUfdrCk>>xOfDP*z*5OJFp>KGJOcDeLf?&DSg8vON1AEHy7hVHS8NOs@&9}haKIoRTHcXUOXLO*oIN&E*H7wbo<$Z z&-fD$wGf)sB(H_5sV|w(rWd``TQ#hXE?%_Vl3XN#c)t_=rI;5&d;Nu1Rh~zoVVl)Q z9zIR`MlK)u{?A|kDMh&0#0S+OPX4@60A|wuguec2Fx1`r?&{ajlouTn>}NPIU28K_ zwgImNXUU#i-B!@Y8sA!qkc+s^Cjo0q?CcRdRP)Z>14Z3Ay3Ge(m#%w6s} zmXw9B(D}vKp=qNUAust{bE)LzQSuWtQ2J*BCmtlNkJH~;kNmO|wiL&l zZ`yV_V7I-SilvD|`^5|#tHspfqK_>~U9~kwNK!I0ZKW9Q%1M}4vskXZzp62L znCLv_#knX0g!BEG6i@Gk>x~90DmuIj$Im4dt{4?o$UQN#)46(ex@Pv`9+`o&*HuSLjFo0d<=z_FmWxduP;z8i7e&(a#Tx)1qs#H0JmVbC9HknpPf=U`URxWf#Whb+x()+W#ese1& zbe1up8__3z7bbamz5i8!ISZqGBpIKDcOwU~5Bk=ub5L@q6G1^|^5r%|LvrQ|TNX*k z=_jyTbdH#$pgKLN&@9&MD=ce{H6jVlf!XFPx5R@50%Gn2$%{Yuz?B|&`^Dai&=A+D z$5oD28%dbM<~Ke%Iy>->XPmy@_$B{xOhuIYndRAqs8a$_t!&Y(q&V`3FU#UVChT06 zho0&5jy~y(_*~iTQ9?UW>YTt_q3?tdTYVhHheF7seFD2+zUG{h2IAwmCDb$O}106JJ?hFzv2?v6JKQa?{{;pW{nfc*mm%>p{zo zh3ooHUA_qKS}g3=Pa+YXzWnyTrZj(m)eAI^*Huy#^9?>7BtA})aJm-36>DE0WTz;<@Z9yLv9!LzG@=~pl`Wx zxgmx7TC>w(UxUj|mo%&jR`rOeto^)6>NM8hZ<~*vdg;}jbjhtd>CszSdHPb#yY?Lx zUr&Cn@d&F)-EyqI;?h<(B=`PCiH61a)ewQk^X0N3!L0ZCZ8BdDp!wF0EmVfRpV;Ma zwqJg5?pYUS{gLNVp3In0lImmEmP#Fj#^pcDG|nXtd>EQ{lFOhj{1nratouAMt9>|! z`2AY1R=^EW#+K_>`;f+rwe+-!6}K>CyZ+gRh7}U*N4m8O)+n$(seMuLQjl^v<{Fa4 zqQ(>26OWuP@mMv}##&Lwo|J>!8|>)#SdcTng|2tnF1z2#jK|R3DN0_eplYi?o

Y z(N--%_QY?0S?jDDZ+Hj{XRp0Kxn5(v@`*a~5rJni za}jS{w1#4yc9#VwtB7y^!Xq#HIdrY_cKG_xjZrpMR4NK`LV8ZG-cr5Ap2)H&G5ryq z5{>=Y(ib|O1jn+tCmyH=u9l~*d<`$7^Jw?sk3XHI<0HIt#7ow7DNkg#{f+T!U)C^> z$B7+TdOpH@OkR?%w5z8nE7P0tl=%H|40hQ#fP79{F!S==)*HG3#vL!v!q936;CR_e z8(V~5@C!q4K$LO`*-Xp8UB4%7 zJYNW|{*s*9_Zo6e46TVg{a#2ofF?LLuS|G*EcbnoMGE4m6GcFYFacjdQFyRqp4{+) zH14qP`XYb&nViV2I+`g?WnaY`E#gcASL32fse$TiDDt8pj-0I31o$r=0+8sQXy6Nx zJIBAVBpi@N8NXX^0Cd?d7-VX#EUWql@23UsE*$gzr&}N2#n&5ulMNwzsN%u}Izs#x zxgSC=XVvHUB~arXS9M<3dh?(u| zxy_5cw>KAUFFc%%cg(k_iAgSlU+W@R8REE*ZPCW7j?^oRaTQv2KLY?34z_@|w}4C` ziWV=tOpGSIn&Olp)YoyJB)`UqRdv25;_`TID#FGr*PMRGR+Ov+-IDtDw)>U%SW)K& zUiFfh1WNsh*Ci)CDlBB!P?%6Lp3gZB!C9Nj`m1(b(W+K{jdUt>MO zxhvQ^>0#ezh1l9>JaV$>l_phdA_!k2cW+gutcVES*M`|L+)6pGY+xw7{&&X{`3#<) zVxOpV3O4Fj;;JY#Jsi#wZP219x4Sq!-9A~?eto#)i7ZiK>|1xOa_-^$DnsJ40(%=m zG3%yGL$%#!EeEN}JUH3R%>CR>A?FCoHPv*-Z*c5+R+c6T6eZt`;L=aCi75&jwNH~& zj6paSdkZ&|ES#!*zZL{O%ev6c?*lr%!l~+K4!D@Nny@{nu5SGj*5FizST5 zh~vG^A71pQk@(!CaO4uhR~;i#FXql<|K%XD8fLs3vDZHjzSH6=sqO3X()a z!R5mm)2SuLyn~~!KdftvPtbjNIl@%wM)-ATYG~KOJqMn|pe5(ntsC7o;fW!2_~q6* zEw$7I@N+6!fkc7Zh|9JkT(dbYWr;qLD_bF}MJxk2zh2Ja6CJNi4V+8X7vXh=r>q?D zgGi5`iXalId*dk_EO>uHo`v`PC^XIUV`0&N+4LpEKzwocmhn*S)0gUE+~s_d$z+g< z+*{4gO)c|xv@Ooux#&s!{C&S znHUQ~NlR&Q?N_-ktPEaN>b4t3$5f~6Ryd#biwk;?FfLDpb2wEI=EnD!O`oOBVla46 zzjJf6vdM&{O^A zJ#DZWCN0&kP|? 
z2E)C}Qrs4@SjuoW2d!2zp2VY4Y`(6s3*qVbdBXQ$&yj1>^zOWlQT#2 zZ!ek7_b@fhkC<%LGFY#Z(ep>dpwV%+q+u zx_V*pj!*#9n&|vwDo-EH=S3+OL;SI+4dyPMsn&PK0k($Eq8cHTA6pP^uy%|ChLb?5&8b>_~bl~zKW>oyPk`~34!DR7LopAF^NY-#_kaO*t0pRufo|ytSd=*i|JqX%5JG}ZVALwZ2rx0_>&0JW-|4X-C`xokyn9@5WWDYlhbDvbReOj&oEfj`YipXIL)D z=(`yz*;*T&cBRA+M3J83uR6NT_?Clu-QH^wmN|7HxygVBH~ z7nL3&zklm3oW5FEf0xc2%RPY4v|znM}146G4G{OmK;ZtbD?{HCi)w?)X7 z_X?TmD3|L_Fv;mpzi$+$HY`^j5)GVbkTfJp_FcI-n6h=Ic>2<)WLDfIbv-BrT;;rr z3g*nFG^PsD=auEBmA9$=@ud63Ekf(ctqPV-IE%To zLw$uX3`fPS!wmzno^Ld`+M=G7p-WivZc5$vUy~_0bLSAI&|y0XC0Yd>k}es z;&Yj!0daWte#NY=#s-FXRHD>0X0xyOPld=@P(Q!Jl1uU8yyJv;I=_N!Dp~2vSCavA zc%K&NmJZZB|_?;VJ3CKibMd+wjDG!dp+I$Jl-n^)-th0?elYpCA1ZWof!A+osDdX0vD;p0)#N!*JAu_=?zPkkfLun1$3v6+`? zMR!{{&7_RwsZYpXscxQ5c5;8#eq}VizB;gdhj&ID`;h%g6#8V6#FevklJrYOJ(&-! zI&|2_Frf+t6mbp7P49Ka+oo@?6zb0maMX2fHdH*PJ*PE3&)zcJ?mwa$^Y#VyQ(=7H zjzeikfajFdH5O?H*N0Gz?7{9)HP0_%!B-3GY{|%nP?^N2nz($3Kmk)KRoT-ra{B+2dRuq(dxedK)J%RrTJZ z@8{tn!nkN<-oKi--1_)vZelk-n?=z(7ga91Q-KWht!p_S-eeA65$)M3y2!jP zJ$wu4#QNT<`E7sebmZ5GFnqa)M^+3UCkyQcxs0pS_#{lt6rJ*|Lo^gOl__6VHdGUz z_+%p?#^xL~uxQlKef&zi4Yh~t4gr1UIEx&q=E$oY#gzhqk>1U7xK3=_!2%f1>Cf--n%3?US;2hpzCE z1s>4;0hr{k9u^SC{o`f_-&e6V-3#2<8&YkjJSNOs{)E{S$8>}nzJ(1M+luDNWbv|e z(%R)v3XrNVn3SRPO_(xt^XoXgh{naGXCWd`T7fX{C2hS*9(5+whwolKj$B}{X9yc=yq*U<_$&ATzr15`3zM| zlQyBCN~Ik9M&zPRI%|5>g}XOxgf0h(CtR?c46duY&`Y{hxokPD!C+-D{=gQOBPrYR z@;QluLiDk1aU#od)O+f$-aPVP~ z`1yke<>=vU@2Am8ijf}}+Fnx7C3z$qMFxP?yYT8n; zbYg#(@`hNbvwGHgTR7tFJZ}_hTFA=AD-JiT7wNm>=r9JVxh-~2KlIIvrK>AP&Nzwe z8H{hyOFWxrhMpy&9#k=XpYq1vsM^hizU%_Iv*;ZWE((UO1@`+V?kMM%WZqL%amW>V zFrRqmHZ*p%hH%8}=L zPx?v`!($6)m<#Sb@_im6x@tH5?q)sop_nSP+mkHFW$8XTih_?W1$oH?gA1>ffD@)H zxZD>!%7Xn|5Ak$hcGrWNW4RTr#~Lj-3!PF`4Q|W68{2-# zwmlZYXVHhxL0&m&Mxnsl@k~ZFlqMFjNPpjr@kYKtYbM{knL>9yQ_{#wrB_K?v49=J@XT35;p6Pf-w~g#2&M~T~Y^mmOnTi?&o$X0K`9U!K3b{^k^(vMl zGs(H;=(&I=_{j9bxEnHv%AB`x3>1pQuC%1OS@ln@m^Mu7_6dZI*?pF@FlX^6i7*o# zJM9w$eRt&XX^vPlngXHIjkiqFMxtvxlzUbs#P4G4AdyRkW@BFGWXan>H;fqP+GNg+ 
z)g3F%1ml1{xlh_2Mgc;Jxk+yCqeVTZ6!;lC=^LNrC$fy!a3^2rnoj_bSo+EwlPjI~as^mF> zI2J=H?W3}|G35y1>M1AAwlM-S2UhRVUiFmtm+sb4G8$}SS>#vuNG(|+bFTK`Qa6v| zJ~Ig%z~up;6p8wakwrrF8~%H)=HM$ml@R%{G!OK`ddczwU6Vw{HXa>n_s!aiDw&zi z55+3oU=+*0zl(A@&v<`JhKRr|E~4d%Tz=S6R1W3c(bbvab;w#P+XNP(mlp*5HY_Ea zBt!c{c=;HjbzflpreEBn*PYSF6i?eS6<2D#et&K%gx|HePHgE|16^ZbPk zA<;;6z)=^g!w_f{7#elNlp=sqjO*U9AG2iPJf|o3+MSJo2=U~_A%fGShyz+b0Mlar z0@5N870SR29d(%e7Bvgxtp(x!480~dxKfnslr(W;`{5E|pHk)&zaUGWM5yhFCMdOd zaZ)giI9_=;3Bn%XXte=FmrBW;Mso+hl1 zkp7|PJqF&JPwNV%jN6~OP^_rU3^b>Z6XkfwD#SlM?%7{j(_dpX`Rv_Wxq;wIT|K8# z$)3^Cj32<-0f_2v)(s96Rs0D}u2b?HM>~)7#XCS4)9T^~4^>){upZFtK|qR!_rFCV z5e3XTpvS~F)BKoxc4q}*WX<7~duWnwGqhVm&KG}Kru*a^bCg$MH`>;^%*q$sv6}p6 z2`D8lB(`bH7aw`Zuhy7!CHV3?ocZjNm;8iHXS)rsxn_=;p|0w#><;ohai3Jix7~DN zZ%C1vLx%I6X>BIkZ5vLrN+uV)C*2H&DzU0bH7!F&J{)BYyAk*BqC=|9MxPvv6#t7T zMdECT@H=AP<}S7MS@NZs8MirKWBaQorAJP*MlqbG^K(np-KU6!bMFKn)x>hf ztkeDX-P>XgpNu{|Hyy_&2m7|Q(;e$tR2$ny^JhGL+V(Z8Wh{Tl>*A7ziM5Vg2T?`p zGYes&F*WSyL%~buAbcAK^z|Tk1pf=O5mc@W!-o@h7zCWruJv5bwWJP_yT7+Cy}iNM zM=hbs&XW|)#_bx*tfMx;cRjw5^y;lQc4la1b_wp7BIeCr$m)&Yx0IXX1Mwwds01sl z*Ks6!-dVbra_Bn_>=aM$jWr;5NNip3Gz002^`O03#zW;$rxvo)#C1s&A&kVb3WR%) zoFeWThMW~@SQVp;iY9%_v#0%K&6%W7ErK=5&c0^aFeR51?<8&Xohl{+?P4zSlh!;U z3O%*DoIcMg48NW&WqL|1yb;{!ct>lnntZKsti0Ayzuan&xLh+&+M z{OVmWe~%9TA780m@vo9J4~LUO)w4zw2W_{h^CB|kY0*B zLm7dg^wuDyz*z+|zqFPhx+gJh^LQ(&dd05uhnG)ioHr;T#uXPlbr3Yc=V|h!dix26 z3yzk$t0DL9FagEZi*4`sb}Mc-4jVVY_Lug5-3wsD zes9lX0j1mb0$uj5Fh3w5PK33wj!@r>3Kf$#UP{DGF{>*T2%TrDcq)TEk*kcsSF(~P zMy6!RO~x(MmRU<>*QQX^Z(b!<`MG+hF>G9N5aDeiDP$K=G(n3J~VtBotRw);^f zp6Yjc-SectDV>78{9;1}e^ZcCGL7}tmz6Axfzg-;Y~WAz0h;==f9StZpF|{<91ngB z-?`e4Ra)O~Gryy(Zf7s|=%7@_!VwhhQ)-2FPf9DHg<{F{r@dF^Ib@^7>!_pYlpo-N z_%BGyOXH%Y%js`2-O(;FNK8^^FHxN;wQj0^KB7C&Djwz#wi~=asJ^STs2}@5;bvhp zeCM=!75~jpnJq(!FmeNBUJOUrXF^WGS^9TKclSjeYnR zZQ0%cJAKL9Op$kU83pcTwRN&UrXy15o!IeA7VkWAaXPEv_#my5Cpn}7DP!@jM!Gnu 
z{;C0Q$yZ`Ei?XY%Jx&?@xyM$DE#-Q9v5(TK<8`cDk}0aP{Ck=el@x2Vknb3|MClu4#UugHzmRMhTx(tQ2A_I{P$>+4|-=j+Y~sO10@g+%;?{qX>Zp8#0#?O}T& zoc%$m`(s{SNX=*aD-_jK@-X8WHg264ay9+Tud*9v_0DZF7mH3rEZ^x=ZQ7(Rq3q!- zl#KFFg7R$DO_@VCFK->C&qhL(=Y*S|nIcs2 z{LT+MTRR>Lc47(k&Z?YQ+t!lJp-BTdr671;?W83 zZWN2I->Quk5R!to`7EJ>Z8Hd1=IFDTI~ytaidv3TS=<@Wm^ZGc4`d6~I$n}PZ?mRH z5tl_*Fher}o9KY4nWj$_5%j+`HTQe}W`?J*K4@7b#-PFBeSDYuu|BJ>XQ|sKaa7Sh zQV)vBNvwHmgKE|U*mK&7L0pOfT{P~thdsK)Z^ zs=IXr2B?=Uxonrljgn(#-OU3UiP|2LFRR{?&Z*L57+6U{6jtz9;C@6d`uWrofLl?LZ=x_?Ip5$Pb%!j#%XnJXx}oej6-O!a+Hso@2E5}_8Nzq%DHvq z1vth0=vj#|5ntx$L~PSjm-W|My~X#w4u@i%^RX}J>FVHK2%TMyZRS8Mcxtb^d>m}K zc^tKpu{$?FOpY=Si*7zzFp9mEk))#OQ6UyX(gkl}_(+=8Ktm+EBS85)f|-IYw)HhB zNxHWYeP(D1dg932j7zpuM@gqk8VoCasbo)GiQYV%D?jhMuvfTRRx>noj&9kf`2gh{ z1SS8*TXU+&IN^yLd)9Mt-QVKDYh1#`z}dZ_W9>%?NkW8Obf@oy-i2~#RlSLNXw~o< z3HN(V5a@Mr$X0t+SRE>ZQ|3u`TlHM0A9hS6>WmpIBDTmU+YMzKn{S$e(^ysvnfISq zGAI;G5^2lVZ)B!8B|l)oWJ!7{3BgtBEQo>5NopsUMA~VdL6V1~0;CxN@{+T<@kPi( zA4!IxohUzNb4R$`(g@3eFFAAC8ucjfUg=%Z&R?62r8jWoVm-y3PrLpJjdx1LI)6=A zy7r45*(c}6F|R%3=DMk0&RpwiTdP+$Zb}}q6GjhAzpe_J`5am}H@CfuXJo7wbbt&G zz?yK(UpPvINEuSG{k^05Q#t;~+Bt#ec~7lYo)?H6b@knyg*pfb^g5pBaxt3^4ZoO7 zJ5yT%#ZC9|?zKG;YJ`i7XK-pKxV4OQk~`NU;77_O+eQR3YROvkkCB@toU5sYG8#l| zYkt&;S?{UfACGuiYGEXpQ5(E;tf}a*md}@`ExvN~+lDV^6WEz?6oG5Li;Svh{`3=N z>Pt>}w@b#8FFH`I@AiqkUnVGH5_o+;w-3OvynnI6$OAxZ9~iOtHKnPO)6YkaGC+o( zH_15QUXRBZs!j{P;X4_A-Rzyop`LKKaVUju`2YqFK%xlvU+s-zNr5Th(kAe_$@x%j52cI?U0%=FA?+6?g z8|uqZ5j#cS+n=p=p=bK5=);86{7zkPseE@3JFGm?eO~p6z_CVFA(JY%_s9O z<5ls0fhit3MB&(R#O$tYY{_h+eerq;q1MjV!*`|Bau?h4{QVQVu~iVz-czo->d#sUi8}uHcmJUV2=|8%;4b|EDAVw2KYbV z(_=5!D z7Pq#rFb7W7fk)}|px~abvx5l^IOPZQ&jCDgXJc;R@;&KfZE58KMFBQsb=ASj+zyHY zX5F6VPWDi4dpn?0F780e#mdRt9Kvk@9M6MtTUfgRbZ3BHG_W};21Nr;hy>6ND7O?8 z0|-L~3cf?207wN2zSg4xxy>Nl=D^83Agd)10WDht5s={ul-mZ% zZ42eL1KNhd;lOwUGzTaM;MEnl21>gFVgtygfUXR{Q+?c8)@CkNz&Stg!T9+EXkqWdcqia6uIV~kh?z#-TIc)^K40{fR@CJxf(`x1sh zxV3*I0T}_uH?46Zc9u5g0OQ-2$27; z^&kO6hysih1@3>bJ2Pz06-0t6@;*hJ>x 
z0V)BWIt*ah7+!z@!1DHE@ILUa9vV2FiUG=lbz*_x9?koW2FClHL1G|a3^DQ zfv~R+u(@vm!t-5c|5WuY0WkQj3gA^ouow@>6~H-I95jO;>iQ7_DueuBWWNR!2FL`w zPzkaR+QC2dgSkNBzA*tcf!F`k0x|`ILH6%`uK)`SG(Q4p2;Xb~w1Z#6zN`Ve0Bz!@ zux}5bP5cP^w({L3KzsN(fPC#=19gD5@gsmeN5gmk9Rj?9HgG_|@&Y;m?E-MjKZhSS z@goL>0_YvGAA{+i!oIG*-3Rmh6F|EFwF<`nzJ&Vo{8mpeX{f0aFlQ(B6Ly3!te2n*Mgz1oR~P`M}{9a9;aq z928)0KV=3y0pC*QfTF<^;4wg70X)R}8TQ==AlUs2OTg~GrGEIG@4nR<3T*c6XSo6e z#=!k*Y@n#`&Dlax-%1iSyk^D2NN5pWF%PXT!JzpG9G z0R^6oZ&UyceV3d9fQi0~PJsss3;P56 zVdd6pr$nbd&G@cK)bn5USydo!7Riub$QHS!AvgfzeyC8Ql_B{uQP#*!d4bUzCOLIF znrG&U1HMD@b1?@q3#h3H?h4S%k1Bwy`wdvz+lhl;aJeAPhXCHK zfT7?>7y<#~L2$xQtS}fW+fPHNhm*Ml1Oz4!07HNO10XUF;I}QH|Ionk82p3UeW$@O zzz0`;rvWGgFq2>Np+T7ZJBR6aayLr=dZ3^;>;dV9N9x4GjaP zT))#0AiVmuEMNz}^9(~{{*V 0 else float("nan") + + +def parse_args(): + p = argparse.ArgumentParser(description="Plot ablation studies") + p.add_argument("--predictions", type=str, required=True) + p.add_argument("--output", type=str, default="figures/ablations") + return p.parse_args() + + +def main(): + args = parse_args() + + with open(args.predictions) as f: + data = json.load(f) + + entries = data["predictions"] + T_overhead = data.get("T_overhead_seconds", 0.0) + + # --- Panel (a): topology sweep (SBM inter-density) --- + # Filter to SBM entries, group by inter_density + sbm_entries = [e for e in entries + if e["config"].get("graph", "") == "sbm"] + + density_groups = {} + for e in sbm_entries: + d = e["config"].get("sbm_inter_density", 0.0) + density_groups.setdefault(d, []).append(e) + + densities = sorted(density_groups.keys()) + mape_full = [] + mape_flat = [] + + for d in densities: + grp = density_groups[d] + full_errs, flat_errs = [], [] + for e in grp: + T_meas = e["measured_median_seconds"] + # Full hierarchical prediction is already in predictions.json + T_pred_full = e["predicted_seconds"] + full_errs.append(relative_error(T_pred_full, T_meas)) + + # Flat model: use a single network term = T_intra + T_inter (not max) + # Approximate: flat 
model can't overlap, so T_comm = T_intra + T_inter + bd = e.get("breakdown", {}) + T_comm_hier = bd.get("T_comm_seconds", 0.0) + # Flat approximation: assume both intra and inter are sequential + c_intra = e["partition_stats"].get("c_intra_bytes", 0) + c_inter = e["partition_stats"].get("c_inter_bytes", 0) + # Without knowing the individual bandwidths, use ratio heuristic: + # flat ≈ 2 * max (conservative estimate) + T_comm_flat = T_comm_hier * 2.0 + T_pred_flat = (T_pred_full - T_comm_hier + T_comm_flat) + flat_errs.append(relative_error(T_pred_flat, T_meas)) + + mape_full.append(np.mean(full_errs) * 100 if full_errs else float("nan")) + mape_flat.append(np.mean(flat_errs) * 100 if flat_errs else float("nan")) + + # --- Panel (b): with vs. without T_buffer_copy --- + meas_all = np.array([e["measured_median_seconds"] * 1e3 for e in entries]) + pred_full = np.array([e["predicted_seconds"] * 1e3 for e in entries]) + pred_nobuf = np.array([ + (e["predicted_seconds"] - e.get("breakdown", {}).get("T_buffer_copy_seconds", 0.0)) * 1e3 + for e in entries + ]) + + fig, (ax_a, ax_b) = plt.subplots(1, 2, figsize=(9, 3.8)) + + # --- Panel (a) --- + if densities: + ax_a.plot(densities, mape_full, "o-", color=COLORS["full"], + markersize=5, linewidth=1.2, label="Hierarchical (intra+inter)") + ax_a.plot(densities, mape_flat, "s--", color=COLORS["flat"], + markersize=5, linewidth=1.2, label="Flat (single-tier)") + ax_a.set_xlabel("SBM inter-block edge density") + ax_a.set_ylabel("MAPE (%)") + ax_a.set_title("(a) Hierarchical vs. Flat Model\nover Topology Sweep", fontsize=9) + ax_a.legend() + ax_a.grid(True, linestyle=":", linewidth=0.4) + else: + ax_a.text(0.5, 0.5, "No SBM data found.\nRun bench_end_to_end with --graph sbm.", + ha="center", va="center", transform=ax_a.transAxes, fontsize=8) + ax_a.set_title("(a) Hierarchical vs. 
Flat Model", fontsize=9) + + # --- Panel (b) --- + lo = min(meas_all.min(), pred_full.min(), pred_nobuf.min()) * 0.9 + hi = max(meas_all.max(), pred_full.max(), pred_nobuf.max()) * 1.1 + ax_b.plot([lo, hi], [lo, hi], "k--", linewidth=0.8, label="Ideal") + ax_b.scatter(meas_all, pred_full, s=18, color=COLORS["full"], + alpha=0.8, label="Full model", zorder=3) + ax_b.scatter(meas_all, pred_nobuf, s=18, color=COLORS["no_buffer"], + alpha=0.6, marker="^", label="Without $T_{\\mathrm{buf}}$", zorder=3) + ax_b.set_xlim(lo, hi) + ax_b.set_ylim(lo, hi) + ax_b.set_xlabel("Measured $T_{\\mathrm{layer}}$ (ms)") + ax_b.set_ylabel("Predicted $T_{\\mathrm{layer}}$ (ms)") + ax_b.set_title("(b) Ablation: With vs. Without\n$T_{\\mathrm{buffer-copy}}$ Term", fontsize=9) + ax_b.legend() + ax_b.grid(True, linestyle=":", linewidth=0.4) + + fig.tight_layout() + + out = Path(args.output) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(out) + ".pdf", bbox_inches="tight") + fig.savefig(str(out) + ".png", bbox_inches="tight", dpi=300) + print(f"[plot_ablations] Saved {out}.pdf and {out}.png") + print( + "Caption: Ablation studies. " + "(a) MAPE of the hierarchical cost model (separate intra/inter tiers) vs. " + "a flat model (single bandwidth parameter) across the SBM topology sweep. " + "The hierarchical model degrades more gracefully as inter-block density increases. " + "(b) Predicted vs. measured scatter with (blue circles) and without (red triangles) " + "the $T_{\\mathrm{buffer-copy}}$ term, demonstrating that omitting this term " + "causes systematic under-prediction for large halo sizes." 
+ ) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/visualization/plot_compute.py b/experiments/cost_model_benchmarks/visualization/plot_compute.py new file mode 100644 index 0000000..20ef9ce --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_compute.py @@ -0,0 +1,172 @@ +"""Visualization — GNN Compute Primitive Runtime. + +Two-panel figure: GCN-like (left) vs. edge-conditioned (right). +Each panel shows forward runtime vs. the swept variable (vertices or edges) +with the fitted linear model overlaid. + +Usage:: + + python -m visualization.plot_compute \\ + --gcn-vertex data/compute_gcn_vswp.json \\ + --gcn-edge data/compute_gcn_eswp.json \\ + --edge-vertex data/compute_edge_vswp.json \\ + --edge-edge data/compute_edge_eswp.json \\ + --primitives data/fitted_primitives.json \\ + --output figures/compute +""" + +import argparse +import json +from pathlib import Path + +import numpy as np +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +COLORS = {"gcn": "#2ca02c", "edge": "#ff7f0e"} +plt.rcParams.update( + { + "font.size": 9, + "axes.labelsize": 9, + "legend.fontsize": 8, + "xtick.labelsize": 8, + "ytick.labelsize": 8, + "figure.dpi": 300, + "text.usetex": False, + } +) + + +def load_compute_file(path: str, timing_key: str = "forward_trials_seconds"): + """Returns lists of (sweep_value, median, q25, q75).""" + with open(path) as f: + data = json.load(f) + sweep = data["config"]["sweep"] + rows = [] + for meas in data["measurements"]: + trials = np.array(meas[timing_key]) + rows.append( + ( + meas["params"]["sweep_value"], + meas["params"]["num_vertices"], + meas["params"]["num_edges"], + float(np.median(trials)), + float(np.percentile(trials, 25)), + float(np.percentile(trials, 75)), + ) + ) + rows.sort(key=lambda r: r[0]) + return sweep, rows + + +def fitted_compute(sweep_vals, fixed_val, sweep, model_type, primitives): + params = primitives.get("compute", 
{}).get(model_type, {}).get("forward", None) + if params is None: + return None + a, b, c = params["coeff_V"], params["coeff_E"], params["intercept"] + if sweep == "vertices": + V_arr = np.array(sweep_vals, dtype=float) + E_arr = np.full_like(V_arr, fixed_val) + else: + E_arr = np.array(sweep_vals, dtype=float) + V_arr = np.full_like(E_arr, fixed_val) + return a * V_arr + b * E_arr + c + + +def plot_one_panel(ax, rows, sweep, fixed_val, model_type, primitives, color, title): + xvals = [r[0] for r in rows] + meds = np.array([r[3] for r in rows]) * 1e3 + lo = np.array([r[3] - r[4] for r in rows]) * 1e3 + hi = np.array([r[5] - r[3] for r in rows]) * 1e3 + + ax.errorbar( + xvals, + meds, + yerr=[lo, hi], + fmt="o", + markersize=4, + color=color, + capsize=2, + linewidth=0.8, + elinewidth=0.8, + label="Measured (IQR)", + ) + + fit = fitted_compute(xvals, fixed_val, sweep, model_type, primitives) + if fit is not None: + ax.plot(xvals, fit * 1e3, "--", color=color, linewidth=1.2, label="Fit") + + xlabel = "|V| (vertices)" if sweep == "vertices" else "|E| (edges)" + ax.set_xlabel(xlabel) + ax.set_ylabel("Forward time (ms)") + ax.set_title(title, fontsize=9) + ax.set_xscale("log") + ax.set_yscale("log") + + ax.legend() + ax.grid(True, which="both", linestyle=":", linewidth=0.4) + + +def parse_args(): + p = argparse.ArgumentParser(description="Plot GNN compute primitive results") + p.add_argument("--gcn-vertex", type=str, default=None) + p.add_argument("--gcn-edge", type=str, default=None) + p.add_argument("--edge-vertex", type=str, default=None) + p.add_argument("--edge-edge", type=str, default=None) + p.add_argument("--primitives", type=str, default=None) + p.add_argument("--output", type=str, default="figures/compute") + return p.parse_args() + + +def main(): + args = parse_args() + primitives = {} + if args.primitives: + with open(args.primitives) as f: + primitives = json.load(f) + + fig, axes = plt.subplots(1, 2, figsize=(7, 3)) + + panel_map = [ + ("gcn", "edges", 
args.gcn_edge, axes[0], "GCN-like"), + ("edge", "edges", args.edge_edge, axes[1], "Edge-conditioned"), + ] + + for model_type, sweep_label, path, ax, title in panel_map: + if path is None: + ax.set_visible(False) + continue + sweep, rows = load_compute_file(path) + fixed_val = rows[0][2] if sweep == "vertices" else rows[0][1] # E or V fixed + plot_one_panel( + ax, + rows, + sweep, + fixed_val, + model_type, + primitives, + COLORS[model_type], + title, + ) + + # fig.suptitle("GNN Compute Primitive: Forward Runtime vs. Graph Size", fontsize=10) + fig.tight_layout() + + out = Path(args.output) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(out) + ".pdf", bbox_inches="tight") + fig.savefig(str(out) + ".png", bbox_inches="tight", dpi=300) + print(f"[plot_compute] Saved {out}.pdf and {out}.png") + print( + "Caption: Forward runtime of a single GNN layer vs. subgraph size " + "(vertex sweep top row, edge sweep bottom row) for GCN-like (left) " + "and edge-conditioned (right) message functions. " + "Dashed lines: fitted model $T_{\\mathrm{comp}} = a|V| + b|E| + c$. " + "Error bars span IQR over 50+ trials." + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/visualization/plot_gather.py b/experiments/cost_model_benchmarks/visualization/plot_gather.py new file mode 100644 index 0000000..cddfdc4 --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_gather.py @@ -0,0 +1,120 @@ +"""Visualization — Gather / Scatter-Add Bandwidth. + +Single plot with three curves (contiguous, clustered, random) showing +gather (or scatter-add) runtime vs. k (number of rows gathered). 
"""
Usage::

    python -m visualization.plot_gather \\
        --contiguous data/gather_contiguous_*.json \\
        --clustered data/gather_clustered_*.json \\
        --random data/gather_random_*.json \\
        --operation gather \\
        --output figures/gather
"""

import argparse
import json
from pathlib import Path

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

COLORS = {
    "contiguous": "#1f77b4",
    "clustered": "#2ca02c",
    "random": "#d62728",
}
plt.rcParams.update({
    "font.size": 9, "axes.labelsize": 9, "legend.fontsize": 8,
    "xtick.labelsize": 8, "ytick.labelsize": 8,
    "figure.dpi": 300, "text.usetex": False,
})


def load_gather_file(paths: list, timing_key: str):
    """Merge measurement JSONs; return (k, median, q25, q75) arrays sorted by k."""
    records = []
    for path in paths:
        with open(path) as fh:
            payload = json.load(fh)
        for entry in payload["measurements"]:
            samples = np.array(entry[timing_key])
            records.append((
                entry["params"]["k"],
                float(np.median(samples)),
                float(np.percentile(samples, 25)),
                float(np.percentile(samples, 75)),
            ))
    records.sort(key=lambda rec: rec[0])
    if not records:
        empty = np.array([])
        return empty, empty.copy(), empty.copy(), empty.copy()
    k_col, med_col, q25_col, q75_col = zip(*records)
    return (np.array(k_col), np.array(med_col),
            np.array(q25_col), np.array(q75_col))


def parse_args():
    """Parse CLI arguments for the gather/scatter-add plot."""
    p = argparse.ArgumentParser(description="Plot gather/scatter-add bandwidth")
    p.add_argument("--contiguous", nargs="+", default=[], metavar="FILE")
    p.add_argument("--clustered", nargs="+", default=[], metavar="FILE")
    p.add_argument("--random", nargs="+", default=[], metavar="FILE")
    p.add_argument("--operation", choices=["gather", "scatter_add"], default="gather")
    p.add_argument("--output", type=str, default="figures/gather")
    return p.parse_args()


def main():
    """Render the three-curve gather/scatter-add figure and print its caption."""
    args = parse_args()
    timing_key = (
        "gather_trials_seconds" if args.operation == "gather"
        else "scatter_add_trials_seconds"
    )

    fig, ax = plt.subplots(figsize=(5, 3.5))

    curve_inputs = [
        ("contiguous", args.contiguous),
        ("clustered", args.clustered),
        ("random", args.random),
    ]
    for dist_name, files in curve_inputs:
        if not files:
            continue
        k, med, q25, q75 = load_gather_file(files, timing_key)
        # Asymmetric IQR error bars around the median, in milliseconds.
        ax.errorbar(
            k * 1e-6, med * 1e3,
            yerr=[(med - q25) * 1e3, (q75 - med) * 1e3],
            fmt="o-", markersize=3, color=COLORS[dist_name], linewidth=0.9,
            capsize=2, elinewidth=0.8, label=dist_name.capitalize(),
        )

    ax.set_xscale("log")
    ax.set_yscale("log")
    op_label = "Gather $x[\\mathrm{idx}]$" if args.operation == "gather" \
        else "Scatter-add (backward)"
    ax.set_xlabel("k (millions of rows gathered)")
    ax.set_ylabel(f"{op_label} time (ms)")
    ax.set_title(f"Buffer-Copy Bandwidth: {op_label}", fontsize=9)
    ax.legend()
    ax.grid(True, which="both", linestyle=":", linewidth=0.4)
    fig.tight_layout()

    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(str(out) + ".pdf", bbox_inches="tight")
    fig.savefig(str(out) + ".png", bbox_inches="tight", dpi=300)
    print(f"[plot_gather] Saved {out}.pdf and {out}.png")
    print(
        f"Caption: {op_label} time vs. gather size $k$ for three index "
        "distributions: contiguous (best case, cache-friendly), clustered "
        "(METIS-partitioned halo pattern), and random (worst case). "
        "Error bars span IQR. The gap between contiguous/clustered and random "
        "quantifies the cache-miss penalty relevant to poorly-partitioned graphs."
    )


if __name__ == "__main__":
    main()

# --- patch continues: new file experiments/cost_model_benchmarks/visualization/plot_pingpong.py ---
# """Visualization — Ping-Pong Bandwidth / Latency.
#
# Produces a log-log plot of one-way transfer time vs.
message size for intra- +and inter-node measurements, with fitted lines overlaid and a residuals inset. + +Usage:: + + python -m visualization.plot_pingpong \\ + --intra data/pingpong_intra_*.json \\ + --inter data/pingpong_inter_*.json \\ + --primitives data/fitted_primitives.json \\ + --output figures/pingpong +""" + +import argparse +import json +from pathlib import Path + +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from matplotlib.ticker import LogLocator + +# --------------------------------------------------------------------------- +# Shared style +# --------------------------------------------------------------------------- +COLORS = {"intra": "#1f77b4", "inter": "#d62728"} +plt.rcParams.update({ + "font.size": 9, + "axes.labelsize": 9, + "legend.fontsize": 8, + "xtick.labelsize": 8, + "ytick.labelsize": 8, + "figure.dpi": 300, + "text.usetex": False, +}) + + +def load_measurements(files: list) -> tuple: + """Return (byte_sizes, medians, q25, q75) from a list of JSON files.""" + byte_sizes, medians, q25s, q75s = [], [], [], [] + for p in files: + with open(p) as f: + data = json.load(f) + for meas in data["measurements"]: + trials = np.array(meas["trials_seconds"]) + byte_sizes.append(meas["params"]["message_bytes"]) + medians.append(float(np.median(trials))) + q25s.append(float(np.percentile(trials, 25))) + q75s.append(float(np.percentile(trials, 75))) + order = np.argsort(byte_sizes) + return ( + np.array(byte_sizes)[order], + np.array(medians)[order], + np.array(q25s)[order], + np.array(q75s)[order], + ) + + +def fitted_line(bytes_arr: np.ndarray, primitives: dict, mode: str) -> np.ndarray: + net = primitives.get("network", {}).get(mode, None) + if net is None: + return None + t_L = net["latency_seconds"] + B = net["bandwidth_bytes_per_sec"] + return t_L + bytes_arr / B + + +def parse_args(): + p = argparse.ArgumentParser(description="Plot ping-pong results") + p.add_argument("--intra", nargs="+", 
default=[], metavar="FILE") + p.add_argument("--inter", nargs="+", default=[], metavar="FILE") + p.add_argument("--primitives", type=str, default=None) + p.add_argument("--output", type=str, default="figures/pingpong") + return p.parse_args() + + +def main(): + args = parse_args() + + primitives = {} + if args.primitives: + with open(args.primitives) as f: + primitives = json.load(f) + + fig, axes = plt.subplots(1, 2, figsize=(7, 3.2)) + ax_main, ax_res = axes + + for mode, files in [("intra", args.intra), ("inter", args.inter)]: + if not files: + continue + xb, med, q25, q75 = load_measurements(files) + color = COLORS[mode] + label = "NVLink (intra)" if mode == "intra" else "InfiniBand (inter)" + + yerr_lo = med - q25 + yerr_hi = q75 - med + ax_main.errorbar( + xb * 1e-6, med * 1e3, + yerr=[yerr_lo * 1e3, yerr_hi * 1e3], + fmt="o", markersize=3, color=color, label=label, + capsize=2, linewidth=0.8, elinewidth=0.8, + ) + + fit = fitted_line(xb, primitives, mode) + if fit is not None: + ax_main.plot(xb * 1e-6, fit * 1e3, "--", color=color, + linewidth=1.2, label=f"{label} fit") + # Residuals + residuals = (med - fit) / med * 100 # percent + ax_res.plot(xb * 1e-6, residuals, "o-", markersize=3, color=color, + linewidth=0.8, label=label) + + ax_main.set_xscale("log") + ax_main.set_yscale("log") + ax_main.set_xlabel("Message size (MB)") + ax_main.set_ylabel("One-way transfer time (ms)") + ax_main.legend(loc="upper left") + ax_main.grid(True, which="both", linestyle=":", linewidth=0.4) + + ax_res.axhline(0, color="k", linewidth=0.6, linestyle="--") + ax_res.set_xscale("log") + ax_res.set_xlabel("Message size (MB)") + ax_res.set_ylabel("Residual (%)") + ax_res.legend() + ax_res.grid(True, which="both", linestyle=":", linewidth=0.4) + + fig.suptitle("Network Ping-Pong: Transfer Time vs. 
Message Size", fontsize=10) + fig.tight_layout() + + out = Path(args.output) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(out) + ".pdf", bbox_inches="tight") + fig.savefig(str(out) + ".png", bbox_inches="tight", dpi=300) + print(f"[plot_pingpong] Saved {out}.pdf and {out}.png") + print( + "Caption: Log-log plot of one-way network transfer time vs. message size " + "for intra-node (NVLink) and inter-node (InfiniBand) communication. " + "Points show medians; error bars span the 25th--75th percentile (IQR). " + "Dashed lines show linear-latency fits $T = t_L + s/B$. " + "Right panel shows residuals (\\%) from the fit." + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/visualization/plot_tipping_point.py b/experiments/cost_model_benchmarks/visualization/plot_tipping_point.py new file mode 100644 index 0000000..735c96d --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_tipping_point.py @@ -0,0 +1,141 @@ +"""Visualization — T_global vs. K (Tipping Point). + +Shows total training throughput or per-layer time as a function of the number +of GPUs K for a fixed graph. Overlays the cost-model prediction (solid) and +measured values (dashed with markers). Annotates K* — the point beyond which +adding more GPUs yields diminishing returns. + +T_global(K) is computed as: + + T_global(K) = T_layer(K) * num_layers * num_epochs + +For the tipping-point annotation K* is the largest K where the speedup +relative to K=1 is still within 10% of linear. 
+ +Usage:: + + python -m visualization.plot_tipping_point \\ + --predictions data/predictions.json \\ + --num-layers 3 \\ + --num-epochs 100 \\ + --graph erdos_renyi \\ + --output figures/tipping_point +""" + +import argparse +import json +from pathlib import Path + +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +plt.rcParams.update({ + "font.size": 9, "axes.labelsize": 9, "legend.fontsize": 8, + "xtick.labelsize": 8, "ytick.labelsize": 8, + "figure.dpi": 300, "text.usetex": False, +}) + + +def parse_args(): + p = argparse.ArgumentParser(description="Plot T_global vs. K tipping point") + p.add_argument("--predictions", type=str, required=True) + p.add_argument("--num-layers", type=int, default=3) + p.add_argument("--num-epochs", type=int, default=100) + p.add_argument("--graph", type=str, default=None, + help="Filter to this graph type (optional)") + p.add_argument("--output", type=str, default="figures/tipping_point") + return p.parse_args() + + +def main(): + args = parse_args() + + with open(args.predictions) as f: + data = json.load(f) + + entries = data["predictions"] + + # Filter by graph type if requested + if args.graph: + entries = [e for e in entries + if e["config"].get("graph", "") == args.graph] + + # Group by world_size + ws_to_meas = {} + ws_to_pred = {} + for e in entries: + K = e["config"].get("world_size", 1) + ws_to_meas.setdefault(K, []).append(e["measured_median_seconds"]) + ws_to_pred.setdefault(K, []).append(e["predicted_seconds"]) + + if not ws_to_meas: + print("[plot_tipping_point] No data found. 
Exiting.") + return + + K_vals = sorted(ws_to_meas.keys()) + T_meas = np.array([np.median(ws_to_meas[K]) for K in K_vals]) * args.num_layers * args.num_epochs + T_pred = np.array([np.median(ws_to_pred[K]) for K in K_vals]) * args.num_layers * args.num_epochs + K_arr = np.array(K_vals, dtype=float) + + # Ideal linear scaling from K=1 + T_single = T_meas[0] # K=1 reference + T_ideal = T_single / K_arr + + # Identify K*: largest K where speedup is ≥ 90% of ideal + speedup = T_single / T_meas + ideal_speedup = K_arr + efficiency = speedup / ideal_speedup + kstar_idx = np.where(efficiency >= 0.9)[0] + K_star = K_arr[kstar_idx[-1]] if len(kstar_idx) else K_arr[0] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3.5)) + + # --- Left panel: T_global vs K --- + ax1.plot(K_arr, T_ideal, "k:", linewidth=1, label="Ideal linear scaling") + ax1.plot(K_arr, T_pred, "-", color="#1f77b4", linewidth=1.5, label="Predicted") + ax1.plot(K_arr, T_meas, "o--", color="#d62728", markersize=5, linewidth=1.2, + label="Measured") + ax1.axvline(K_star, color="gray", linestyle="--", linewidth=0.8) + ax1.text(K_star * 1.05, ax1.get_ylim()[1] * 0.95, + f"$K^* = {int(K_star)}$", fontsize=8, color="gray", va="top") + ax1.set_xlabel("Number of GPUs ($K$)") + ax1.set_ylabel(f"$T_{{\\mathrm{{global}}}}$ (s)\n" + f"({args.num_layers} layers × {args.num_epochs} epochs)") + ax1.set_title("Training Time vs. 
GPU Count", fontsize=9) + ax1.legend() + ax1.grid(True, linestyle=":", linewidth=0.4) + + # --- Right panel: scaling efficiency --- + ax2.plot(K_arr, efficiency * 100, "o-", color="#2ca02c", markersize=5, linewidth=1.2, + label="Scaling efficiency") + ax2.axhline(90, color="gray", linestyle="--", linewidth=0.8, label="90% threshold") + ax2.axvline(K_star, color="gray", linestyle="--", linewidth=0.8) + ax2.set_xlabel("Number of GPUs ($K$)") + ax2.set_ylabel("Scaling efficiency (%)") + ax2.set_title("Strong Scaling Efficiency", fontsize=9) + ax2.set_ylim(0, 110) + ax2.legend() + ax2.grid(True, linestyle=":", linewidth=0.4) + + fig.suptitle("Tipping-Point Analysis: When Does Adding GPUs Stop Helping?", fontsize=10) + fig.tight_layout() + + out = Path(args.output) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(out) + ".pdf", bbox_inches="tight") + fig.savefig(str(out) + ".png", bbox_inches="tight", dpi=300) + print(f"[plot_tipping_point] Saved {out}.pdf and {out}.png K*={int(K_star)}") + print( + f"Caption: (Left) Total training time $T_{{\\mathrm{{global}}}}$ vs. GPU count $K$ " + f"for {args.num_layers}-layer GNN trained for {args.num_epochs} epochs. " + "Solid blue: cost-model prediction. Red dashed: measured. " + "Dotted black: ideal linear speedup. " + f"$K^* = {int(K_star)}$ marks the last GPU count with $\\geq 90\\%$ scaling efficiency. " + "(Right) Scaling efficiency $= \\text{{speedup}} / K$." + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/visualization/plot_validation.py b/experiments/cost_model_benchmarks/visualization/plot_validation.py new file mode 100644 index 0000000..65a6497 --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_validation.py @@ -0,0 +1,127 @@ +"""Visualization — Predicted vs. Measured Scatter (Headline Figure). + +Reads ``data/predictions.json`` and produces a predicted-vs-measured scatter +plot. Points are colored by graph type (or fit/held-out split). 
+ +Usage:: + + python -m visualization.plot_validation \\ + --predictions data/predictions.json \\ + --color-by split \\ + --output figures/validation +""" + +import argparse +import json +from pathlib import Path + +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +plt.rcParams.update({ + "font.size": 9, "axes.labelsize": 9, "legend.fontsize": 8, + "xtick.labelsize": 8, "ytick.labelsize": 8, + "figure.dpi": 300, "text.usetex": False, +}) + + +def parse_args(): + p = argparse.ArgumentParser(description="Plot predicted vs. measured scatter") + p.add_argument("--predictions", type=str, required=True) + p.add_argument("--color-by", choices=["split", "graph", "world_size"], + default="split") + p.add_argument("--output", type=str, default="figures/validation") + return p.parse_args() + + +def main(): + args = parse_args() + + with open(args.predictions) as f: + data = json.load(f) + + entries = data["predictions"] + mape_fit = data["aggregate"]["mape_fit_set"] + mape_held = data["aggregate"]["mape_held_out"] + + meas = np.array([e["measured_median_seconds"] * 1e3 for e in entries]) + pred = np.array([e["predicted_seconds"] * 1e3 for e in entries]) + + # Color groups + if args.color_by == "split": + groups = { + "Fit set": [i for i, e in enumerate(entries) if e["in_fit_set"]], + "Held-out": [i for i, e in enumerate(entries) if not e["in_fit_set"]], + } + palette = {"Fit set": "#1f77b4", "Held-out": "#d62728"} + elif args.color_by == "graph": + graph_types = sorted(set(e["config"].get("graph", "unknown") for e in entries)) + palette = dict(zip(graph_types, ["#1f77b4", "#ff7f0e", "#2ca02c", "#9467bd"])) + groups = {g: [i for i, e in enumerate(entries) + if e["config"].get("graph", "unknown") == g] + for g in graph_types} + else: # world_size + ws_vals = sorted(set(e["config"].get("world_size", 1) for e in entries)) + colors = plt.cm.viridis(np.linspace(0, 1, len(ws_vals))) + palette = {f"K={w}": c for w, c in zip(ws_vals, 
colors)} + groups = {f"K={w}": [i for i, e in enumerate(entries) + if e["config"].get("world_size", 1) == w] + for w in ws_vals} + + fig, ax = plt.subplots(figsize=(4.5, 4.5)) + + all_vals = np.concatenate([meas, pred]) + lo, hi = all_vals.min() * 0.9, all_vals.max() * 1.1 + ax.plot([lo, hi], [lo, hi], "k--", linewidth=0.8, label="Ideal (y=x)") + + # 10% error bands + ax.fill_between([lo, hi], [lo * 0.9, hi * 0.9], [lo * 1.1, hi * 1.1], + alpha=0.08, color="gray") + + for label, idxs in groups.items(): + if not idxs: + continue + color = palette[label] + ax.scatter(meas[idxs], pred[idxs], s=22, color=color, + alpha=0.85, label=label, zorder=3) + + ax.set_xlim(lo, hi) + ax.set_ylim(lo, hi) + ax.set_xlabel("Measured $T_{\\mathrm{layer}}$ (ms)") + ax.set_ylabel("Predicted $T_{\\mathrm{layer}}$ (ms)") + ax.set_title("Cost Model Validation: Predicted vs. Measured", fontsize=9) + + # Annotate MAPE + mape_text = ( + f"Fit MAPE = {mape_fit*100:.1f}%\n" + f"Held-out MAPE = {mape_held*100:.1f}%" + ) + ax.text(0.04, 0.96, mape_text, transform=ax.transAxes, + verticalalignment="top", fontsize=8, + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)) + + ax.legend(loc="lower right", fontsize=8) + ax.grid(True, linestyle=":", linewidth=0.4) + fig.tight_layout() + + out = Path(args.output) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(out) + ".pdf", bbox_inches="tight") + fig.savefig(str(out) + ".png", bbox_inches="tight", dpi=300) + print(f"[plot_validation] Saved {out}.pdf and {out}.png") + print( + "Caption: Predicted vs. measured layer time $T_{\\mathrm{layer}}$ for all " + "benchmarked configurations. Dashed diagonal: perfect prediction. " + "Shaded band: $\\pm10\\%$ error region. " + f"In-sample MAPE = {mape_fit*100:.1f}\\%, " + f"held-out MAPE = {mape_held*100:.1f}\\%. " + "Colors distinguish " + + ("fit-set vs. held-out configurations." 
if args.color_by == "split" + else f"configurations by {args.color_by}.") + ) + + +if __name__ == "__main__": + main() From 060db3ec265d0c7f030756366b9684f254245b4f Mon Sep 17 00:00:00 2001 From: Shehtab Date: Mon, 13 Apr 2026 14:44:41 -0400 Subject: [PATCH 4/7] Update the gather fit to use non-linear fit --- .../analysis/fit_primitives.py | 96 ++++++++++++++----- 1 file changed, 71 insertions(+), 25 deletions(-) diff --git a/experiments/cost_model_benchmarks/analysis/fit_primitives.py b/experiments/cost_model_benchmarks/analysis/fit_primitives.py index 02c9233..a18ec67 100644 --- a/experiments/cost_model_benchmarks/analysis/fit_primitives.py +++ b/experiments/cost_model_benchmarks/analysis/fit_primitives.py @@ -30,12 +30,14 @@ import numpy as np from scipy import stats as sp_stats +from scipy.optimize import curve_fit # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def load_json_files(paths: list) -> list: records = [] for p in paths: @@ -70,6 +72,7 @@ def linear_fit(x: np.ndarray, y: np.ndarray): # Network fit: T = t_L + bytes / B # --------------------------------------------------------------------------- + def fit_network(records: list) -> dict: """Fit (t_L, B) from ping-pong records (one mode per call).""" bytes_arr = [] @@ -103,6 +106,7 @@ def fit_network(records: list) -> dict: # Compute fit: T = coeff_V * |V| + coeff_E * |E| + intercept # --------------------------------------------------------------------------- + def fit_compute(records: list, timing_key: str = "forward_trials_seconds") -> dict: """Fit compute cost as a function of |V| and |E|. 
@@ -136,16 +140,19 @@ def fit_compute(records: list, timing_key: str = "forward_trials_seconds") -> di # --------------------------------------------------------------------------- -# Gather fit: T = intercept + bytes / B_gather +# Gather fit: T = intercept + max(overhead, overhead + bytes / B_gather) # --------------------------------------------------------------------------- + def fit_gather(records: list, timing_key: str = "gather_trials_seconds") -> dict: k_arr, T_arr, F_arr = [], [], [] for rec in records: F = rec["config"]["feature_dim"] for meas in rec["measurements"]: k = meas["params"]["k"] + t_med = median_of_trials(meas[timing_key]) + k_arr.append(k) T_arr.append(t_med) F_arr.append(F) @@ -153,14 +160,42 @@ def fit_gather(records: list, timing_key: str = "gather_trials_seconds") -> dict k_arr = np.array(k_arr, dtype=float) F_arr = np.array(F_arr, dtype=float) T_arr = np.array(T_arr, dtype=float) - bytes_arr = k_arr * F_arr * 4.0 # float32 - fit = linear_fit(bytes_arr, T_arr) - bandwidth = 1.0 / fit["slope"] if fit["slope"] > 0 else float("nan") + + # 1. Define the piecewise model for curve_fit + def time_model(b, overhead, inv_bandwidth): + # Time is bounded by a constant overhead until bandwidth saturation is reached + return np.maximum(overhead, b * inv_bandwidth) + + # 2. Provide sensible initial guesses (p0) to help the optimizer + min_T = np.min(T_arr) + slope_guess = (np.max(T_arr) - min_T) / (np.max(bytes_arr) + 1e-9) + p0 = [min_T, slope_guess] + + # 3. Fit the curve + try: + # Bounds ensure overhead and time-per-byte are strictly positive + popt, _ = curve_fit(time_model, bytes_arr, T_arr, p0=p0, bounds=(0, np.inf)) + launch_overhead = popt[0] + inv_bandwidth = popt[1] + except Exception: + raise RuntimeError("Curve fit failed for gather primitive") + + bandwidth = 1.0 / inv_bandwidth if inv_bandwidth > 0 else float("nan") + + # 4. 
Calculate R-squared manually (curve_fit doesn't return it) + if not np.isnan(launch_overhead): + T_pred = time_model(bytes_arr, launch_overhead, inv_bandwidth) + ss_res = np.sum((T_arr - T_pred) ** 2) + ss_tot = np.sum((T_arr - np.mean(T_arr)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else float("nan") + else: + r_squared = float("nan") + return { "bandwidth_bytes_per_sec": bandwidth, - "intercept_seconds": fit["intercept"], - "r_squared": fit["r_squared"], + "intercept_seconds": launch_overhead, + "r_squared": r_squared, "_num_points": len(T_arr), } @@ -169,6 +204,7 @@ def fit_gather(records: list, timing_key: str = "gather_trials_seconds") -> dict # Main # --------------------------------------------------------------------------- + def parse_args(): p = argparse.ArgumentParser(description="Fit cost-model primitive parameters") p.add_argument("--pingpong-intra", nargs="+", default=[], metavar="FILE") @@ -191,15 +227,19 @@ def main(): if args.pingpong_intra: recs = load_json_files(args.pingpong_intra) net["intra"] = fit_network(recs) - print(f"[network/intra] B={net['intra']['bandwidth_bytes_per_sec']/1e9:.2f} GB/s " - f"t_L={net['intra']['latency_seconds']*1e6:.2f} µs " - f"R²={net['intra']['r_squared']:.4f}") + print( + f"[network/intra] B={net['intra']['bandwidth_bytes_per_sec']/1e9:.2f} GB/s " + f"t_L={net['intra']['latency_seconds']*1e6:.2f} µs " + f"R²={net['intra']['r_squared']:.4f}" + ) if args.pingpong_inter: recs = load_json_files(args.pingpong_inter) net["inter"] = fit_network(recs) - print(f"[network/inter] B={net['inter']['bandwidth_bytes_per_sec']/1e9:.2f} GB/s " - f"t_L={net['inter']['latency_seconds']*1e6:.2f} µs " - f"R²={net['inter']['r_squared']:.4f}") + print( + f"[network/inter] B={net['inter']['bandwidth_bytes_per_sec']/1e9:.2f} GB/s " + f"t_L={net['inter']['latency_seconds']*1e6:.2f} µs " + f"R²={net['inter']['r_squared']:.4f}" + ) result["network"] = net # Compute @@ -210,37 +250,43 @@ def main(): "forward": fit_compute(recs, 
"forward_trials_seconds"), "backward": fit_compute(recs, "backward_trials_seconds"), } - print(f"[compute/gcn] coeff_V={comp['gcn']['forward']['coeff_V']:.3e} " - f"coeff_E={comp['gcn']['forward']['coeff_E']:.3e} " - f"R²={comp['gcn']['forward']['r_squared']:.4f}") + print( + f"[compute/gcn] coeff_V={comp['gcn']['forward']['coeff_V']:.3e} " + f"coeff_E={comp['gcn']['forward']['coeff_E']:.3e} " + f"R²={comp['gcn']['forward']['r_squared']:.4f}" + ) if args.compute_edge: recs = load_json_files(args.compute_edge) comp["edge"] = { "forward": fit_compute(recs, "forward_trials_seconds"), "backward": fit_compute(recs, "backward_trials_seconds"), } - print(f"[compute/edge] coeff_V={comp['edge']['forward']['coeff_V']:.3e} " - f"coeff_E={comp['edge']['forward']['coeff_E']:.3e} " - f"R²={comp['edge']['forward']['r_squared']:.4f}") + print( + f"[compute/edge] coeff_V={comp['edge']['forward']['coeff_V']:.3e} " + f"coeff_E={comp['edge']['forward']['coeff_E']:.3e} " + f"R²={comp['edge']['forward']['r_squared']:.4f}" + ) result["compute"] = comp # Gather gath = {} for dist_name, files_attr in [ ("contiguous", "gather_contiguous"), - ("clustered", "gather_clustered"), - ("random", "gather_random"), + ("clustered", "gather_clustered"), + ("random", "gather_random"), ]: files = getattr(args, files_attr.replace("-", "_")) if files: recs = load_json_files(files) gath[dist_name] = { - "gather": fit_gather(recs, "gather_trials_seconds"), - "scatter_add": fit_gather(recs, "scatter_add_trials_seconds"), + "gather": fit_gather(recs, "gather_trials_seconds"), + "scatter_add": fit_gather(recs, "scatter_add_trials_seconds"), } - print(f"[gather/{dist_name}] " - f"B_gather={gath[dist_name]['gather']['bandwidth_bytes_per_sec']/1e9:.2f} GB/s " - f"R²={gath[dist_name]['gather']['r_squared']:.4f}") + print( + f"[gather/{dist_name}] " + f"B_gather={gath[dist_name]['gather']['bandwidth_bytes_per_sec']/1e9:.2f} GB/s " + f"R²={gath[dist_name]['gather']['r_squared']:.4f}" + ) result["gather"] = gath # 
Write From ecddcd60be3a04bcae0606817bf9391db580a1b2 Mon Sep 17 00:00:00 2001 From: Shehtab Date: Fri, 17 Apr 2026 13:24:13 -0400 Subject: [PATCH 5/7] Decompose common benchmark funcs --- .../analysis/fit_primitives.py | 96 ++++-- .../benchmarks/bench_end_to_end.py | 300 +++--------------- .../benchmarks/graph_data_common.py | 210 ++++++++++++ .../benchmarks/nn_layer_common.py | 41 +++ .../visualization/plot_gather.py | 151 +++++++-- 5 files changed, 494 insertions(+), 304 deletions(-) create mode 100644 experiments/cost_model_benchmarks/benchmarks/graph_data_common.py create mode 100644 experiments/cost_model_benchmarks/benchmarks/nn_layer_common.py diff --git a/experiments/cost_model_benchmarks/analysis/fit_primitives.py b/experiments/cost_model_benchmarks/analysis/fit_primitives.py index a18ec67..6211f94 100644 --- a/experiments/cost_model_benchmarks/analysis/fit_primitives.py +++ b/experiments/cost_model_benchmarks/analysis/fit_primitives.py @@ -31,6 +31,7 @@ import numpy as np from scipy import stats as sp_stats from scipy.optimize import curve_fit +from scipy.special import expit # --------------------------------------------------------------------------- @@ -140,7 +141,7 @@ def fit_compute(records: list, timing_key: str = "forward_trials_seconds") -> di # --------------------------------------------------------------------------- -# Gather fit: T = intercept + max(overhead, overhead + bytes / B_gather) +# Gather fit: T = intercept + max(overhead, bytes / B_gather) # --------------------------------------------------------------------------- @@ -150,9 +151,7 @@ def fit_gather(records: list, timing_key: str = "gather_trials_seconds") -> dict F = rec["config"]["feature_dim"] for meas in rec["measurements"]: k = meas["params"]["k"] - t_med = median_of_trials(meas[timing_key]) - k_arr.append(k) T_arr.append(t_med) F_arr.append(F) @@ -160,43 +159,86 @@ def fit_gather(records: list, timing_key: str = "gather_trials_seconds") -> dict k_arr = np.array(k_arr, 
dtype=float) F_arr = np.array(F_arr, dtype=float) T_arr = np.array(T_arr, dtype=float) + bytes_arr = k_arr * F_arr * 4.0 # float32 - # 1. Define the piecewise model for curve_fit - def time_model(b, overhead, inv_bandwidth): - # Time is bounded by a constant overhead until bandwidth saturation is reached - return np.maximum(overhead, b * inv_bandwidth) + # 1. The Piecewise Linear Model (No Logarithms) + def time_model(b, overhead, inv_bw_L2, inv_bw_HBM, L2_thresh, HBM_thresh): + # 1. Bucket the bytes into their respective physical regimes + # Bytes processed exclusively at L2 speeds + bytes_L2 = np.clip(b, 0, L2_thresh) + # Bytes processed exclusively at HBM speeds + bytes_HBM = np.maximum(0, b - HBM_thresh) + + # 2. Apply the specific bandwidth (slope) to each bucket + t_mem = (bytes_L2 * inv_bw_L2) + (bytes_HBM * inv_bw_HBM) - # 2. Provide sensible initial guesses (p0) to help the optimizer - min_T = np.min(T_arr) - slope_guess = (np.max(T_arr) - min_T) / (np.max(bytes_arr) + 1e-9) - p0 = [min_T, slope_guess] + # 3. Floor the total time by the kernel launch overhead + return np.maximum(overhead, t_mem) + + # 2. Strategic initial guesses + min_T, max_T = float(np.min(T_arr)), float(np.max(T_arr)) + max_b = float(np.max(bytes_arr)) + + # The asymptotic slope is the difference in max/min time over max bytes + inv_bw_HBM_guess = (max_T - min_T) / (max_b + 1e-9) + + p0 = [ + min_T, # overhead + inv_bw_HBM_guess * 0.3, # inv_bw_L2 + inv_bw_HBM_guess, # inv_bw_HBM + max_b * 0.00001, # L2_thresh + max_b * 0.001, # HBM_thresh + ] - # 3. Fit the curve try: - # Bounds ensure overhead and time-per-byte are strictly positive - popt, _ = curve_fit(time_model, bytes_arr, T_arr, p0=p0, bounds=(0, np.inf)) - launch_overhead = popt[0] - inv_bandwidth = popt[1] - except Exception: - raise RuntimeError("Curve fit failed for gather primitive") - - bandwidth = 1.0 / inv_bandwidth if inv_bandwidth > 0 else float("nan") - - # 4. 
Calculate R-squared manually (curve_fit doesn't return it) - if not np.isnan(launch_overhead): - T_pred = time_model(bytes_arr, launch_overhead, inv_bandwidth) + bounds = ([0, 0, 0, 0, 0], [np.inf, np.inf, np.inf, max_b, max_b]) + + # 3. The Magic Fix: sigma=T_arr weights the fit by relative error + popt, _ = curve_fit( + time_model, + bytes_arr, + T_arr, + p0=p0, + bounds=bounds, + method="trf", + sigma=T_arr, + absolute_sigma=False, + ) + overhead, inv_bw_L2, inv_bw_HBM, L2_thresh, HBM_thresh = popt + + except Exception as e: + print(f"Fit failed: {e}") + overhead = inv_bw_L2 = inv_bw_HBM = L2_thresh = np.nan + + bw_HBM = 1.0 / inv_bw_HBM if inv_bw_HBM > 0 else float("nan") + bw_L2 = 1.0 / inv_bw_L2 if inv_bw_L2 > 0 else float("nan") + + # 3. Calculate linear R-squared in linear space + if not np.isnan(overhead): + T_pred = time_model( + bytes_arr, overhead, inv_bw_L2, inv_bw_HBM, L2_thresh, HBM_thresh + ) ss_res = np.sum((T_arr - T_pred) ** 2) ss_tot = np.sum((T_arr - np.mean(T_arr)) ** 2) r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else float("nan") else: r_squared = float("nan") + print( + f" Fitted gather: overhead={overhead*1e3:.3f} ms ", + f"BW_L2={bw_L2/1e9:.2f} GB/s BW_HBM={bw_HBM/1e9:.2f} GB/s " + f"L2_thresh={L2_thresh/1e6:.2f} MB HBM_thresh={HBM_thresh/1e6:.2f} MB " + f"R²={r_squared:.4f}", + ) + return { - "bandwidth_bytes_per_sec": bandwidth, - "intercept_seconds": launch_overhead, + "bandwidth_bytes_per_sec": bw_HBM, + "L2_bandwidth_bytes_per_sec": bw_L2, + "L2_inflection_bytes": L2_thresh, + "HBM_inflection_bytes": HBM_thresh, + "launch_overhead_seconds": overhead, "r_squared": r_squared, - "_num_points": len(T_arr), } diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_end_to_end.py b/experiments/cost_model_benchmarks/benchmarks/bench_end_to_end.py index 7da40eb..4cfe162 100644 --- a/experiments/cost_model_benchmarks/benchmarks/bench_end_to_end.py +++ b/experiments/cost_model_benchmarks/benchmarks/bench_end_to_end.py @@ -51,201 
+51,14 @@ setup_distributed, write_result, ) - - -# =========================================================================== -# Synthetic graph generators -# =========================================================================== - -def gen_erdos_renyi(num_vertices: int, avg_degree: float, - rng: np.random.Generator) -> np.ndarray: - """Return edge array of shape [E, 2] (src, dst) for an Erdős-Rényi digraph.""" - num_edges = int(num_vertices * avg_degree) - src = rng.integers(0, num_vertices, size=num_edges) - dst = rng.integers(0, num_vertices, size=num_edges) - return np.stack([src, dst], axis=1) - - -def gen_sbm(num_vertices: int, avg_degree: float, inter_density: float, - rng: np.random.Generator) -> np.ndarray: - """Return edges for a Stochastic Block Model graph. - - Vertices are split into blocks of equal size (one per rank for convenience, - though the actual partitioning is a separate step). The ratio of - intra-block to inter-block edges is controlled by *inter_density*. 
- """ - world_size = dist.get_world_size() if dist.is_initialized() else 4 - block_size = num_vertices // world_size - edges = [] - target_edges = int(num_vertices * avg_degree) - - intra_edges = int(target_edges * (1.0 - inter_density)) - inter_edges = target_edges - intra_edges - - # Intra-block edges - for b in range(world_size): - start = b * block_size - end = start + block_size - n = intra_edges // world_size - s = rng.integers(start, end, size=n) - d = rng.integers(start, end, size=n) - edges.append(np.stack([s, d], axis=1)) - - # Inter-block edges - s = rng.integers(0, num_vertices, size=inter_edges) - d = rng.integers(0, num_vertices, size=inter_edges) - # Force cross-block by offsetting dst block - d_block = (s // block_size + 1 + rng.integers(0, world_size - 1, - size=inter_edges)) % world_size - d = d_block * block_size + rng.integers(0, block_size, size=inter_edges) - d = np.clip(d, 0, num_vertices - 1) - edges.append(np.stack([s, d], axis=1)) - - return np.concatenate(edges, axis=0) - - -# =========================================================================== -# Partitioners -# =========================================================================== - -def partition_random(num_vertices: int, world_size: int, - rng: np.random.Generator) -> np.ndarray: - return rng.integers(0, world_size, size=num_vertices).astype(np.int64) - - -def partition_balanced(num_vertices: int, world_size: int) -> np.ndarray: - return np.floor(np.arange(num_vertices) * world_size / num_vertices).astype(np.int64) - - -def partition_metis(num_vertices: int, world_size: int, - edges: np.ndarray) -> np.ndarray: - try: - import pymetis - except ImportError: - raise RuntimeError( - "pymetis is not installed. Install it with: pip install pymetis\n" - "Or use --partitioner random or --partitioner balanced." 
- ) - # Build adjacency list for pymetis - adj = [[] for _ in range(num_vertices)] - for s, d in edges: - adj[s].append(int(d)) - adj[d].append(int(s)) - _, membership = pymetis.part_graph(world_size, adjacency=adj) - return np.array(membership, dtype=np.int64) - - -# =========================================================================== -# Minimal halo-exchange infrastructure -# =========================================================================== - -def build_local_comm_pattern(edges: np.ndarray, assignment: np.ndarray, - rank: int, world_size: int): - """Compute the local communication pattern for this rank. - - Returns a dict with: - local_vertices — np.ndarray of vertex IDs owned by this rank - local_edge_index — torch.Tensor [2, E_local] with local vertex IDs - remapped so that 0..n_local-1 are owned vertices - and n_local..n_local+n_halo-1 are halo vertices - send_counts — list[int] of length world_size: vertices to send - recv_counts — list[int] of length world_size: vertices to recv - send_idx — local indices (into local_vertices) to send per rank - halo_global_ids — global vertex IDs of halo vertices, in recv order - intra_halo_size — halo vertices from same node (ranks sharing node) - inter_halo_size — halo vertices from remote nodes - ranks_per_node — int (derived from LOCAL_RANK / RANK relationship) - """ - local_mask = assignment == rank - local_vertices = np.where(local_mask)[0] - n_local = len(local_vertices) - - # Global -> local index map - g2l = {int(v): i for i, v in enumerate(local_vertices)} - - # Find edges where dst is local - local_dst_mask = np.isin(edges[:, 1], local_vertices) - local_edges = edges[local_dst_mask] - - # Halo: src vertices not owned by this rank - halo_src_mask = ~np.isin(local_edges[:, 0], local_vertices) - halo_global = np.unique(local_edges[halo_src_mask, 0]) - - # Group halo vertices by owning rank - halo_owners = assignment[halo_global] - recv_by_rank = [] - halo_order = [] - for r in range(world_size): - 
verts = halo_global[halo_owners == r] - recv_by_rank.append(verts) - halo_order.extend(verts.tolist()) - halo_order = np.array(halo_order, dtype=np.int64) - - # Global halo id -> local halo index - halo_g2l = {int(v): n_local + i for i, v in enumerate(halo_order)} - all_g2l = {**g2l, **halo_g2l} - - # Find which local vertices other ranks need (send pattern) - # We exchange recv_counts via all_to_all to learn send_counts - recv_counts = [len(rv) for rv in recv_by_rank] - - # Build send: for each rank r, which of our local vertices does r need? - # We do a global exchange of halo_global per rank - all_recv = [None] * world_size - dist.all_gather_object(all_recv, halo_order.tolist()) - - send_idx_by_rank = [] - for r in range(world_size): - needed = np.array(all_recv[r], dtype=np.int64) - owned_mask = assignment[needed] == rank if len(needed) > 0 else np.array([], dtype=bool) - owned = needed[owned_mask] if len(needed) > 0 else np.array([], dtype=np.int64) - # Map to local indices - local_idxs = np.array([g2l[int(v)] for v in owned], dtype=np.int64) - send_idx_by_rank.append(local_idxs) - - send_counts = [len(s) for s in send_idx_by_rank] - - # Remap edges to local indices - valid_edge_mask = np.array([ - (int(s) in all_g2l) and (int(d) in all_g2l) - for s, d in local_edges - ]) - local_edges_valid = local_edges[valid_edge_mask] - if len(local_edges_valid) > 0: - remapped_src = np.array([all_g2l[int(s)] for s in local_edges_valid[:, 0]]) - remapped_dst = np.array([all_g2l[int(d)] for d in local_edges_valid[:, 1]]) - edge_index = torch.tensor( - np.stack([remapped_src, remapped_dst], axis=0), dtype=torch.long - ) - else: - edge_index = torch.zeros((2, 0), dtype=torch.long) - - # Compute intra / inter halo sizes - ranks_per_node = int(os.environ.get("LOCAL_WORLD_SIZE", - os.environ.get("SLURM_NTASKS_PER_NODE", "4"))) - my_node = rank // ranks_per_node - intra_halo_size = 0 - inter_halo_size = 0 - for r, verts in enumerate(recv_by_rank): - peer_node = r // 
ranks_per_node - if peer_node == my_node: - intra_halo_size += len(verts) - else: - inter_halo_size += len(verts) - - return { - "local_vertices": local_vertices, - "n_local": n_local, - "n_halo": len(halo_order), - "edge_index": edge_index, - "send_counts": send_counts, - "recv_counts": recv_counts, - "send_idx_by_rank": send_idx_by_rank, - "halo_order": halo_order, - "intra_halo_size": intra_halo_size, - "inter_halo_size": inter_halo_size, - "ranks_per_node": ranks_per_node, - } +from benchmarks.graph_data_common import ( + gen_erdos_renyi, + gen_sbm, + partition_balanced, + partition_metis, + partition_random, +) +from benchmarks.nn_layer_common import GCNLayer, EdgeConditionedLayer class MinimalHaloExchange(torch.autograd.Function): @@ -260,14 +73,20 @@ def forward(ctx, x_local, send_idx_flat, send_counts, recv_counts, world_size): # Split by destination rank send_list = list(send_buf.split(send_counts, dim=0)) - recv_list = [torch.zeros(rc, x_local.shape[1], - dtype=x_local.dtype, device=x_local.device) - for rc in recv_counts] + recv_list = [ + torch.zeros( + rc, x_local.shape[1], dtype=x_local.dtype, device=x_local.device + ) + for rc in recv_counts + ] dist.all_to_all(recv_list, send_list) - recv_buf = torch.cat(recv_list, dim=0) if sum(recv_counts) > 0 else \ - torch.zeros(0, x_local.shape[1], device=x_local.device) + recv_buf = ( + torch.cat(recv_list, dim=0) + if sum(recv_counts) > 0 + else torch.zeros(0, x_local.shape[1], device=x_local.device) + ) ctx.save_for_backward(send_idx_flat) ctx.send_counts = send_counts @@ -281,7 +100,7 @@ def forward(ctx, x_local, send_idx_flat, send_counts, recv_counts, world_size): @staticmethod def backward(ctx, grad_recv): - send_idx_flat, = ctx.saved_tensors + (send_idx_flat,) = ctx.saved_tensors send_counts = ctx.send_counts recv_counts = ctx.recv_counts world_size = ctx.world_size @@ -290,8 +109,11 @@ def backward(ctx, grad_recv): device = ctx.device # Reverse: recv_counts become send_counts and vice versa - 
grad_recv_list = list(grad_recv.split(recv_counts, dim=0)) \ - if grad_recv.shape[0] > 0 else [torch.zeros(0, F, device=device)] * world_size + grad_recv_list = ( + list(grad_recv.split(recv_counts, dim=0)) + if grad_recv.shape[0] > 0 + else [torch.zeros(0, F, device=device)] * world_size + ) grad_send_list = [torch.zeros(sc, F, device=device) for sc in send_counts] dist.all_to_all(grad_send_list, grad_recv_list) @@ -308,59 +130,27 @@ def backward(ctx, grad_recv): return grad_x_local, None, None, None, None -# =========================================================================== -# GNN layers (same as bench_compute.py for consistency) -# =========================================================================== - -class GCNLayer(nn.Module): - def __init__(self, feature_dim: int): - super().__init__() - self.linear = nn.Linear(feature_dim, feature_dim, bias=False) - - def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor: - src, dst = edge_index[0], edge_index[1] - n_local = x.shape[0] - # Only update local vertices (dst < n_local guard not needed since - # edge_index already restricts to local dst) - msg = self.linear(x[src]) - out = torch.zeros_like(x) - out.scatter_add_(0, dst.unsqueeze(1).expand_as(msg), msg) - return out - - -class EdgeConditionedLayer(nn.Module): - def __init__(self, feature_dim: int): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(3 * feature_dim, feature_dim), - nn.ReLU(), - nn.Linear(feature_dim, feature_dim), - ) - - def forward(self, x: torch.Tensor, edge_index: torch.Tensor, - edge_attr: torch.Tensor) -> torch.Tensor: - src, dst = edge_index[0], edge_index[1] - msg = self.mlp(torch.cat([x[src], x[dst], edge_attr], dim=-1)) - out = torch.zeros_like(x) - out.scatter_add_(0, dst.unsqueeze(1).expand_as(msg), msg) - return out - - # =========================================================================== # Main # =========================================================================== + def 
parse_args(): p = argparse.ArgumentParser(description="End-to-end halo exchange benchmark") p.add_argument("--graph", choices=["erdos_renyi", "sbm"], default="erdos_renyi") p.add_argument("--num-vertices", type=int, default=100_000) p.add_argument("--avg-degree", type=float, default=20.0) - p.add_argument("--sbm-inter-density", type=float, default=0.1, - help="Fraction of inter-block edges for SBM graphs") + p.add_argument( + "--sbm-inter-density", + type=float, + default=0.1, + help="Fraction of inter-block edges for SBM graphs", + ) p.add_argument("--feature-dim", type=int, default=128) p.add_argument("--model", choices=["gcn", "edge"], default="gcn") - p.add_argument("--partitioner", choices=["random", "balanced", "metis"], - default="balanced") + p.add_argument( + "--partitioner", choices=["random", "balanced", "metis"], default="balanced" + ) p.add_argument("--warmup", type=int, default=10) p.add_argument("--trials", type=int, default=50) p.add_argument("--output", type=str, required=True) @@ -380,8 +170,7 @@ def main(): if args.graph == "erdos_renyi": edges = gen_erdos_renyi(args.num_vertices, args.avg_degree, rng) else: - edges = gen_sbm(args.num_vertices, args.avg_degree, - args.sbm_inter_density, rng) + edges = gen_sbm(args.num_vertices, args.avg_degree, args.sbm_inter_density, rng) # --- Partition --- rng_part = np.random.default_rng(args.seed + 1) @@ -400,9 +189,13 @@ def main(): send_counts = pattern["send_counts"] recv_counts = pattern["recv_counts"] - send_idx_flat = torch.cat([ - torch.tensor(s, dtype=torch.long) for s in pattern["send_idx_by_rank"] - ]).to(device) if sum(send_counts) > 0 else torch.zeros(0, dtype=torch.long, device=device) + send_idx_flat = ( + torch.cat( + [torch.tensor(s, dtype=torch.long) for s in pattern["send_idx_by_rank"]] + ).to(device) + if sum(send_counts) > 0 + else torch.zeros(0, dtype=torch.long, device=device) + ) # --- Model --- if args.model == "gcn": @@ -413,8 +206,11 @@ def main(): # --- Synthetic local node 
features --- x_local = torch.randn(n_local, F, device=device, requires_grad=True) - edge_attr = torch.randn(edge_index.shape[1], F, device=device) \ - if args.model == "edge" else None + edge_attr = ( + torch.randn(edge_index.shape[1], F, device=device) + if args.model == "edge" + else None + ) # --- Timed forward + backward --- def one_layer(): diff --git a/experiments/cost_model_benchmarks/benchmarks/graph_data_common.py b/experiments/cost_model_benchmarks/benchmarks/graph_data_common.py new file mode 100644 index 0000000..15358e5 --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/graph_data_common.py @@ -0,0 +1,210 @@ +import numpy as np +import torch.distributed as dist + + +# =========================================================================== +# Synthetic graph generators +# =========================================================================== + + +def gen_erdos_renyi( + num_vertices: int, avg_degree: float, rng: np.random.Generator +) -> np.ndarray: + """Return edge array of shape [E, 2] (src, dst) for an Erdős-Rényi digraph.""" + num_edges = int(num_vertices * avg_degree) + src = rng.integers(0, num_vertices, size=num_edges) + dst = rng.integers(0, num_vertices, size=num_edges) + return np.stack([src, dst], axis=1) + + +def gen_sbm( + num_vertices: int, avg_degree: float, inter_density: float, rng: np.random.Generator +) -> np.ndarray: + """Return edges for a Stochastic Block Model graph. + + Vertices are split into blocks of equal size (one per rank for convenience, + though the actual partitioning is a separate step). The ratio of + intra-block to inter-block edges is controlled by *inter_density*. 
+ """ + world_size = dist.get_world_size() if dist.is_initialized() else 4 + block_size = num_vertices // world_size + edges = [] + target_edges = int(num_vertices * avg_degree) + + intra_edges = int(target_edges * (1.0 - inter_density)) + inter_edges = target_edges - intra_edges + + # Intra-block edges + for b in range(world_size): + start = b * block_size + end = start + block_size + n = intra_edges // world_size + s = rng.integers(start, end, size=n) + d = rng.integers(start, end, size=n) + edges.append(np.stack([s, d], axis=1)) + + # Inter-block edges + s = rng.integers(0, num_vertices, size=inter_edges) + d = rng.integers(0, num_vertices, size=inter_edges) + # Force cross-block by offsetting dst block + d_block = ( + s // block_size + 1 + rng.integers(0, world_size - 1, size=inter_edges) + ) % world_size + d = d_block * block_size + rng.integers(0, block_size, size=inter_edges) + d = np.clip(d, 0, num_vertices - 1) + edges.append(np.stack([s, d], axis=1)) + + return np.concatenate(edges, axis=0) + + +# =========================================================================== +# Partitioners +# =========================================================================== + + +def partition_random( + num_vertices: int, world_size: int, rng: np.random.Generator +) -> np.ndarray: + return rng.integers(0, world_size, size=num_vertices).astype(np.int64) + + +def partition_balanced(num_vertices: int, world_size: int) -> np.ndarray: + return np.floor(np.arange(num_vertices) * world_size / num_vertices).astype( + np.int64 + ) + + +def partition_metis( + num_vertices: int, world_size: int, edges: np.ndarray +) -> np.ndarray: + try: + import pymetis + except ImportError: + raise RuntimeError( + "pymetis is not installed. Install it with: pip install pymetis\n" + "Or use --partitioner random or --partitioner balanced." 
+ ) + # Build adjacency list for pymetis + adj = [[] for _ in range(num_vertices)] + for s, d in edges: + adj[s].append(int(d)) + adj[d].append(int(s)) + _, membership = pymetis.part_graph(world_size, adjacency=adj) + return np.array(membership, dtype=np.int64) + + +# =========================================================================== +# Minimal halo-exchange infrastructure +# =========================================================================== + + +def build_local_comm_pattern( + edges: np.ndarray, assignment: np.ndarray, rank: int, world_size: int +): + """Compute the local communication pattern for this rank. + + Returns a CommunicationPattern object with: + local_vertices — np.ndarray of vertex IDs owned by this rank + local_edge_index — torch.Tensor [2, E_local] with local vertex IDs + remapped so that 0..n_local-1 are owned vertices + and n_local..n_local+n_halo-1 are halo vertices + send_counts — list[int] of length world_size: vertices to send + recv_counts — list[int] of length world_size: vertices to recv + send_idx — local indices (into local_vertices) to send per rank + halo_global_ids — global vertex IDs of halo vertices, in recv order + intra_halo_size — halo vertices from same node (ranks sharing node) + inter_halo_size — halo vertices from remote nodes + ranks_per_node — int (derived from LOCAL_RANK / RANK relationship) + """ + local_mask = assignment == rank + local_vertices = np.where(local_mask)[0] + n_local = len(local_vertices) + + # Global -> local index map + g2l = {int(v): i for i, v in enumerate(local_vertices)} + + # Find edges where dst is local + local_dst_mask = np.isin(edges[:, 1], local_vertices) + local_edges = edges[local_dst_mask] + + # Halo: src vertices not owned by this rank + halo_src_mask = ~np.isin(local_edges[:, 0], local_vertices) + halo_global = np.unique(local_edges[halo_src_mask, 0]) + + # Group halo vertices by owning rank + halo_owners = assignment[halo_global] + recv_by_rank = [] + halo_order = [] + for 
r in range(world_size): + verts = halo_global[halo_owners == r] + recv_by_rank.append(verts) + halo_order.extend(verts.tolist()) + halo_order = np.array(halo_order, dtype=np.int64) + + # Global halo id -> local halo index + halo_g2l = {int(v): n_local + i for i, v in enumerate(halo_order)} + all_g2l = {**g2l, **halo_g2l} + + # Find which local vertices other ranks need (send pattern) + # We exchange recv_counts via all_to_all to learn send_counts + recv_counts = [len(rv) for rv in recv_by_rank] + + # Build send: for each rank r, which of our local vertices does r need? + # We do a global exchange of halo_global per rank + all_recv = [None] * world_size + dist.all_gather_object(all_recv, halo_order.tolist()) + + send_idx_by_rank = [] + for r in range(world_size): + needed = np.array(all_recv[r], dtype=np.int64) + owned_mask = ( + assignment[needed] == rank if len(needed) > 0 else np.array([], dtype=bool) + ) + owned = needed[owned_mask] if len(needed) > 0 else np.array([], dtype=np.int64) + # Map to local indices + local_idxs = np.array([g2l[int(v)] for v in owned], dtype=np.int64) + send_idx_by_rank.append(local_idxs) + + send_counts = [len(s) for s in send_idx_by_rank] + + # Remap edges to local indices + valid_edge_mask = np.array( + [(int(s) in all_g2l) and (int(d) in all_g2l) for s, d in local_edges] + ) + local_edges_valid = local_edges[valid_edge_mask] + if len(local_edges_valid) > 0: + remapped_src = np.array([all_g2l[int(s)] for s in local_edges_valid[:, 0]]) + remapped_dst = np.array([all_g2l[int(d)] for d in local_edges_valid[:, 1]]) + edge_index = torch.tensor( + np.stack([remapped_src, remapped_dst], axis=0), dtype=torch.long + ) + else: + edge_index = torch.zeros((2, 0), dtype=torch.long) + + # Compute intra / inter halo sizes + ranks_per_node = int( + os.environ.get("LOCAL_WORLD_SIZE", os.environ.get("SLURM_NTASKS_PER_NODE", "4")) + ) + my_node = rank // ranks_per_node + intra_halo_size = 0 + inter_halo_size = 0 + for r, verts in 
enumerate(recv_by_rank): + peer_node = r // ranks_per_node + if peer_node == my_node: + intra_halo_size += len(verts) + else: + inter_halo_size += len(verts) + + return { + "local_vertices": local_vertices, + "n_local": n_local, + "n_halo": len(halo_order), + "edge_index": edge_index, + "send_counts": send_counts, + "recv_counts": recv_counts, + "send_idx_by_rank": send_idx_by_rank, + "halo_order": halo_order, + "intra_halo_size": intra_halo_size, + "inter_halo_size": inter_halo_size, + "ranks_per_node": ranks_per_node, + } diff --git a/experiments/cost_model_benchmarks/benchmarks/nn_layer_common.py b/experiments/cost_model_benchmarks/benchmarks/nn_layer_common.py new file mode 100644 index 0000000..893d268 --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/nn_layer_common.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn + +# =========================================================================== +# GNN layers (same as bench_compute.py for consistency) +# =========================================================================== + + +class GCNLayer(nn.Module): + def __init__(self, feature_dim: int): + super().__init__() + self.linear = nn.Linear(feature_dim, feature_dim, bias=False) + + def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor: + src, dst = edge_index[0], edge_index[1] + n_local = x.shape[0] + # Only update local vertices (dst < n_local guard not needed since + # edge_index already restricts to local dst) + msg = self.linear(x[src]) + out = torch.zeros_like(x) + out.scatter_add_(0, dst.unsqueeze(1).expand_as(msg), msg) + return out + + +class EdgeConditionedLayer(nn.Module): + def __init__(self, feature_dim: int): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(3 * feature_dim, feature_dim), + nn.ReLU(), + nn.Linear(feature_dim, feature_dim), + ) + + def forward( + self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor + ) -> torch.Tensor: + src, dst = edge_index[0], 
edge_index[1] + msg = self.mlp(torch.cat([x[src], x[dst], edge_attr], dim=-1)) + out = torch.zeros_like(x) + out.scatter_add_(0, dst.unsqueeze(1).expand_as(msg), msg) + return out diff --git a/experiments/cost_model_benchmarks/visualization/plot_gather.py b/experiments/cost_model_benchmarks/visualization/plot_gather.py index cddfdc4..f33b193 100644 --- a/experiments/cost_model_benchmarks/visualization/plot_gather.py +++ b/experiments/cost_model_benchmarks/visualization/plot_gather.py @@ -19,19 +19,68 @@ import numpy as np import matplotlib + matplotlib.use("Agg") import matplotlib.pyplot as plt COLORS = { "contiguous": "#1f77b4", - "clustered": "#2ca02c", - "random": "#d62728", + "clustered": "#2ca02c", + "random": "#d62728", } -plt.rcParams.update({ - "font.size": 9, "axes.labelsize": 9, "legend.fontsize": 8, - "xtick.labelsize": 8, "ytick.labelsize": 8, - "figure.dpi": 300, "text.usetex": False, -}) +plt.rcParams.update( + { + "font.size": 9, + "axes.labelsize": 9, + "legend.fontsize": 8, + "xtick.labelsize": 8, + "ytick.labelsize": 8, + "figure.dpi": 300, + "text.usetex": False, + } +) + + +def fitted_gather( + min_val, + max_val, + feature_dim, + hbm_bandwidth_bytes_per_sec, + l2_bandwidth_bytes_per_sec, + launch_overhead_seconds, + L2_thresh, + HBM_thresh, +): + """Return a function that models gather time as a function of k.""" + + def time_model(b, overhead, inv_bw_L2, inv_bw_HBM, L2_thresh, HBM_thresh): + # 1. Bucket the bytes into their respective physical regimes + # Bytes processed exclusively at L2 speeds + bytes_L2 = np.clip(b, 0, L2_thresh) + # Bytes processed exclusively at HBM speeds + bytes_HBM = np.maximum(0, b - HBM_thresh) + + # 2. Apply the specific bandwidth (slope) to each bucket + t_mem = (bytes_L2 * inv_bw_L2) + (bytes_HBM * inv_bw_HBM) + + # 3. 
Floor the total time by the kernel launch overhead + return np.maximum(overhead, t_mem) + + x = np.linspace(min_val, max_val, 100) + inv_bw_L2 = 1.0 / l2_bandwidth_bytes_per_sec + inv_bw_HBM = 1.0 / hbm_bandwidth_bytes_per_sec + y = ( + time_model( + x * feature_dim * 4.0, + launch_overhead_seconds, + inv_bw_L2, + inv_bw_HBM, + L2_thresh, + HBM_thresh, + ) + * 1e3 + ) + return x, y def load_gather_file(paths: list, timing_key: str): @@ -42,14 +91,16 @@ def load_gather_file(paths: list, timing_key: str): data = json.load(f) for meas in data["measurements"]: trials = np.array(meas[timing_key]) - rows.append(( - meas["params"]["k"], - float(np.median(trials)), - float(np.percentile(trials, 25)), - float(np.percentile(trials, 75)), - )) + rows.append( + ( + meas["params"]["k"], + float(np.median(trials)), + float(np.percentile(trials, 25)), + float(np.percentile(trials, 75)), + ) + ) rows.sort(key=lambda r: r[0]) - k_arr = np.array([r[0] for r in rows]) + k_arr = np.array([r[0] for r in rows]) med_arr = np.array([r[1] for r in rows]) q25_arr = np.array([r[2] for r in rows]) q75_arr = np.array([r[3] for r in rows]) @@ -59,42 +110,92 @@ def load_gather_file(paths: list, timing_key: str): def parse_args(): p = argparse.ArgumentParser(description="Plot gather/scatter-add bandwidth") p.add_argument("--contiguous", nargs="+", default=[], metavar="FILE") - p.add_argument("--clustered", nargs="+", default=[], metavar="FILE") - p.add_argument("--random", nargs="+", default=[], metavar="FILE") - p.add_argument("--operation", choices=["gather", "scatter_add"], default="gather") - p.add_argument("--output", type=str, default="figures/gather") + p.add_argument("--clustered", nargs="+", default=[], metavar="FILE") + p.add_argument("--random", nargs="+", default=[], metavar="FILE") + p.add_argument("--operation", choices=["gather", "scatter_add"], default="gather") + p.add_argument("--fitted", type=str, default=None, metavar="FILE") + p.add_argument("--output", type=str, 
default="figures/gather") return p.parse_args() def main(): args = parse_args() timing_key = ( - "gather_trials_seconds" if args.operation == "gather" + "gather_trials_seconds" + if args.operation == "gather" else "scatter_add_trials_seconds" ) fig, ax = plt.subplots(figsize=(5, 3.5)) + if args.fitted: + with open(args.fitted) as f: + primitives = json.load(f) + min_k, max_k = 1e3, 1e9 for dist_name, files in [ ("contiguous", args.contiguous), - ("clustered", args.clustered), - ("random", args.random), + ("clustered", args.clustered), + ("random", args.random), ]: if not files: continue k, med, q25, q75 = load_gather_file(files, timing_key) + min_k = k[0] + max_k = k[-1] color = COLORS[dist_name] ax.errorbar( - k * 1e-6, med * 1e3, + k * 1e-6, + med * 1e3, yerr=[(med - q25) * 1e3, (q75 - med) * 1e3], - fmt="o-", markersize=3, color=color, linewidth=0.9, - capsize=2, elinewidth=0.8, label=dist_name.capitalize(), + fmt="o", + markersize=3, + color=color, + linewidth=0.9, + capsize=2, + elinewidth=0.8, + label=dist_name.capitalize(), ) + if args.fitted: + hbm_bw = primitives["gather"][dist_name]["gather"][ + "bandwidth_bytes_per_sec" + ] + l2_bw = primitives["gather"][dist_name]["gather"][ + "L2_bandwidth_bytes_per_sec" + ] + overhead = primitives["gather"][dist_name]["gather"][ + "launch_overhead_seconds" + ] + thresh = primitives["gather"][dist_name]["gather"]["L2_inflection_bytes"] + hbm_thresh = primitives["gather"][dist_name]["gather"][ + "HBM_inflection_bytes" + ] + x, y = fitted_gather( + min_k, + max_k, + feature_dim=512, + hbm_bandwidth_bytes_per_sec=hbm_bw, + l2_bandwidth_bytes_per_sec=l2_bw, + launch_overhead_seconds=overhead, + L2_thresh=thresh, + HBM_thresh=hbm_thresh, + ) + ax.plot( + x * 1e-6, + y, + "--", + color=color, + linewidth=1.2, + label=f"{dist_name.capitalize()}-Expected", + alpha=0.4, + ) ax.set_xscale("log") ax.set_yscale("log") - op_label = "Gather $x[\\mathrm{idx}]$" if args.operation == "gather" \ + op_label = ( + "Gather 
$x[\\mathrm{idx}]$" + if args.operation == "gather" else "Scatter-add (backward)" + ) ax.set_xlabel("k (millions of rows gathered)") ax.set_ylabel(f"{op_label} time (ms)") ax.set_title(f"Buffer-Copy Bandwidth: {op_label}", fontsize=9) From a89e2db662bd485677463b595d165679bad66a84 Mon Sep 17 00:00:00 2001 From: Shehtab Date: Sat, 18 Apr 2026 11:14:44 -0400 Subject: [PATCH 6/7] Add crossover benchmark to measure communication compute tradeoff --- .../benchmarks/bench_crossover.py | 533 ++++++++++++++++++ 1 file changed, 533 insertions(+) create mode 100644 experiments/cost_model_benchmarks/benchmarks/bench_crossover.py diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_crossover.py b/experiments/cost_model_benchmarks/benchmarks/bench_crossover.py new file mode 100644 index 0000000..d6bc829 --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/bench_crossover.py @@ -0,0 +1,533 @@ +"""Benchmark 2.2 — Multi-GPU Crossover Point. + +Sweeps graph size to identify the point at which distributing computation +across K GPUs overcomes the overhead of halo-exchange communication. + +For each graph size N in the sweep: + + * **Single-GPU baseline**: rank 0 runs forward+backward on the *complete* + graph on one GPU. Measures raw compute cost T_comp(N). + * **Multi-GPU distributed**: all K ranks execute partitioned + forward+backward with halo exchange. Measures + T_comp(N/K) + T_comm. + +The *crossover* N* is the smallest graph size where +T_single(N) > T_multi(K, N), i.e. where distributing first becomes +beneficial. + +When ``--no-dist`` is set the script runs the single-GPU sweep only +(useful for characterising baseline compute without a multi-GPU +allocation). 
+
+Synthetic graphs:
+    * ``erdos_renyi`` — Erdős-Rényi with ``--avg-degree`` expected degree
+    * ``sbm``         — Stochastic Block Model; ``--sbm-inter-density``
+                        controls the fraction of inter-block edges
+
+Partitioners:
+    * ``random``   — assign each vertex to a uniformly random rank
+    * ``balanced`` — contiguous vertex blocks of equal size
+    * ``metis``    — balanced k-way via pymetis (skipped if not installed)
+
+Usage — distributed (torchrun)::
+
+    torchrun --nnodes 2 --nproc_per_node 4 \\
+        -m benchmarks.bench_crossover \\
+        --graph erdos_renyi \\
+        --graph-sizes 10000,100000,1000000,10000000 \\
+        --avg-degree 20 --feature-dim 128 --model gcn \\
+        --partitioner balanced --warmup 10 --trials 50 \\
+        --output data/crossover_K8_F128_er_bal.json --seed 42
+
+Usage — single-GPU baseline only::
+
+    python -m benchmarks.bench_crossover --no-dist \\
+        --graph erdos_renyi \\
+        --graph-sizes 10000,100000,1000000,10000000 \\
+        --avg-degree 20 --feature-dim 128 --model gcn \\
+        --warmup 10 --trials 50 \\
+        --output data/crossover_single_F128.json --seed 42
+"""
+
+import argparse
+import os
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from benchmarks.common import (
+    collect_metadata,
+    cuda_timed,
+    seed_everything,
+    setup_distributed,
+    write_result,
+)
+from benchmarks.graph_data_common import (
+    gen_erdos_renyi,
+    gen_sbm,
+    partition_balanced,
+    partition_metis,
+    partition_random,
+)
+from benchmarks.nn_layer_common import GCNLayer, EdgeConditionedLayer
+
+from DGraph.distributed import (
+    HaloExchange,
+    CommunicationPattern,
+    build_communication_pattern,
+)
+from DGraph import Communicator
+
+
+# ---------------------------------------------------------------------------
+# Argument parsing
+# ---------------------------------------------------------------------------
+
+
+def parse_args():
+    p = argparse.ArgumentParser(description="Multi-GPU crossover point benchmark")
+    p.add_argument("--graph", 
choices=["erdos_renyi", "sbm"], default="erdos_renyi") + p.add_argument( + "--graph-sizes", + type=str, + default="10000,100000,1000000,10000000", + help="Comma-separated list of num_vertices values to sweep", + ) + p.add_argument("--avg-degree", type=float, default=20.0) + p.add_argument( + "--sbm-inter-density", + type=float, + default=0.1, + help="Fraction of inter-block edges for SBM graphs", + ) + p.add_argument("--feature-dim", type=int, default=128) + p.add_argument("--model", choices=["gcn", "edge"], default="gcn") + p.add_argument( + "--partitioner", choices=["random", "balanced", "metis"], default="balanced" + ) + p.add_argument("--warmup", type=int, default=10) + p.add_argument("--trials", type=int, default=50) + p.add_argument("--output", type=str, required=True) + p.add_argument("--seed", type=int, default=42) + p.add_argument( + "--no-dist", + action="store_true", + help="Run single-GPU sweep only (no distributed setup required)", + ) + return p.parse_args() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _gen_graph(graph_type, num_vertices, avg_degree, sbm_inter_density, seed): + """Generate a synthetic graph reproducibly for a given graph size.""" + rng = np.random.default_rng(seed) + if graph_type == "erdos_renyi": + return gen_erdos_renyi(num_vertices, avg_degree, rng) + else: + return gen_sbm(num_vertices, avg_degree, sbm_inter_density, rng) + + +def _intra_inter_halo( + comm_pattern: CommunicationPattern, ranks_per_node: int +) -> tuple: + """Return (intra_halo_vertices, inter_halo_vertices) from recv_offset.""" + rank = comm_pattern.rank + my_node = rank // ranks_per_node + recv_counts = ( + comm_pattern.recv_offset[1:] - comm_pattern.recv_offset[:-1] + ).tolist() + intra = 0 + inter = 0 + for r, count in enumerate(recv_counts): + if r == rank: + continue + if (r // ranks_per_node) == my_node: + intra += 
int(count) + else: + inter += int(count) + return intra, inter + + +def _build_single_gpu_tensors(num_vertices, edges_np, F, model, device): + """Allocate full-graph tensors on *device*. May raise cuda.OutOfMemoryError.""" + # GCNLayer / EdgeConditionedLayer expect edge_index as [2, E] + edge_t = torch.from_numpy(edges_np.T.copy()).long().to(device) + x = torch.randn(num_vertices, F, device=device, requires_grad=True) + if model == "gcn": + layer = GCNLayer(F).to(device) + edge_attr = None + else: + layer = EdgeConditionedLayer(F).to(device) + edge_attr = torch.randn(edges_np.shape[0], F, device=device) + layer.train() + return x, edge_t, layer, edge_attr + + +def _single_gpu_fn(x, edge_t, layer, edge_attr): + """Return a zero-argument closure for single-GPU forward+backward.""" + if edge_attr is None: + def fn(): + out = layer(x, edge_t) + out.sum().backward() + if x.grad is not None: + x.grad.zero_() + else: + def fn(): + out = layer(x, edge_t, edge_attr) + out.sum().backward() + if x.grad is not None: + x.grad.zero_() + return fn + + +def _multi_gpu_fn(x_local, comm_pattern, layer, halo_exchange, edge_attr, model): + """Return a zero-argument closure for distributed forward+backward. + + ``comm_pattern.local_edge_list`` has shape ``[E, 2]``; we transpose it + once to ``[2, E]`` as required by GCNLayer / EdgeConditionedLayer. 
+ """ + edge_index = comm_pattern.local_edge_list.T.contiguous() # [2, E_local] + + if model == "gcn": + def fn(): + recv_buf = halo_exchange(x_local, comm_pattern) + x_aug = torch.cat([x_local, recv_buf], dim=0) + out = layer(x_aug, edge_index) + out.sum().backward() + if x_local.grad is not None: + x_local.grad.zero_() + else: + def fn(): + recv_buf = halo_exchange(x_local, comm_pattern) + x_aug = torch.cat([x_local, recv_buf], dim=0) + out = layer(x_aug, edge_index, edge_attr) + out.sum().backward() + if x_local.grad is not None: + x_local.grad.zero_() + return fn + + +# --------------------------------------------------------------------------- +# Single-GPU-only path +# --------------------------------------------------------------------------- + + +def run_no_dist(args, graph_sizes, F): + """Sweep graph sizes on a single GPU and return the measurement list.""" + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + seed_everything(args.seed) + measurements = [] + + for num_vertices in graph_sizes: + # Unique seed per graph size so each topology is reproducible independently + graph_seed = args.seed + num_vertices + edges_np = _gen_graph( + args.graph, num_vertices, args.avg_degree, args.sbm_inter_density, graph_seed + ) + num_edges = int(edges_np.shape[0]) + + times_single = None + oom = False + try: + x, edge_t, layer, edge_attr = _build_single_gpu_tensors( + num_vertices, edges_np, F, args.model, device + ) + fn = _single_gpu_fn(x, edge_t, layer, edge_attr) + times_single = cuda_timed(fn, warmup=args.warmup, trials=args.trials) + # Free before next iteration + del x, edge_t, layer + if edge_attr is not None: + del edge_attr + torch.cuda.empty_cache() + except torch.cuda.OutOfMemoryError: + oom = True + print( + f"[crossover] OOM: N={num_vertices:,} exceeds single-GPU memory; " + "skipping and continuing" + ) + torch.cuda.empty_cache() + + med_str = ( + f"{1e3 * sorted(times_single)[len(times_single) // 2]:.2f} ms" + if times_single + 
else "OOM" + ) + print(f"[crossover] no-dist N={num_vertices:>12,} single={med_str}") + + measurements.append( + { + "params": { + "num_vertices": num_vertices, + "num_edges": num_edges, + "avg_degree": args.avg_degree, + "feature_dim": F, + "model": args.model, + "graph": args.graph, + }, + "single_gpu_trials_seconds": times_single, + "single_gpu_oom": oom, + "multi_gpu_trials_seconds_rank0": None, + "multi_gpu_trials_seconds_max": None, + "per_rank_stats": None, + "world_size": 1, + } + ) + + return measurements + + +# --------------------------------------------------------------------------- +# Distributed path +# --------------------------------------------------------------------------- + + +def run_distributed(args, graph_sizes, F, rank, world_size, local_rank): + """Run one crossover measurement per graph size using all K ranks.""" + device = torch.device(f"cuda:{local_rank}") + + comm = Communicator(backend="nccl") + halo_exchange = HaloExchange(comm=comm) + + ranks_per_node = int( + os.environ.get( + "LOCAL_WORLD_SIZE", os.environ.get("SLURM_NTASKS_PER_NODE", "4") + ) + ) + + measurements = [] + + for num_vertices in graph_sizes: + # All ranks generate *identical* graph topology (same seed per size) + graph_seed = args.seed + num_vertices + edges_np = _gen_graph( + args.graph, num_vertices, args.avg_degree, args.sbm_inter_density, graph_seed + ) + + # All ranks compute *identical* partition assignment + rng_part = np.random.default_rng(args.seed + 1 + num_vertices) + if args.partitioner == "random": + assignment_np = partition_random(num_vertices, world_size, rng_part) + elif args.partitioner == "balanced": + assignment_np = partition_balanced(num_vertices, world_size) + else: + assignment_np = partition_metis(num_vertices, world_size, edges_np) + + # Move to GPU for DGraph's build_communication_pattern (collective) + edges_t = torch.from_numpy(edges_np).long().to(device) # [E, 2] + partitioning_t = torch.from_numpy(assignment_np).long().to(device) # [V] 
+ + # Collective: all ranks call this in sync (internally calls dist.all_gather) + comm_pattern = build_communication_pattern( + edges_t, partitioning_t, rank, world_size + ) + + n_local = comm_pattern.num_local_vertices + n_halo = comm_pattern.num_halo_vertices + n_local_edges = comm_pattern.local_edge_list.shape[0] + + x_local = torch.randn(n_local, F, device=device, requires_grad=True) + edge_attr_dist = ( + torch.randn(n_local_edges, F, device=device) + if args.model == "edge" + else None + ) + layer_dist = ( + GCNLayer(F).to(device) + if args.model == "gcn" + else EdgeConditionedLayer(F).to(device) + ) + layer_dist.train() + + fn_multi = _multi_gpu_fn( + x_local, comm_pattern, layer_dist, halo_exchange, edge_attr_dist, args.model + ) + + # ---- Time multi-GPU (all ranks participate) ---- + dist.barrier() + times_multi_local = cuda_timed(fn_multi, warmup=args.warmup, trials=args.trials) + dist.barrier() + + # Gather timings and partition stats from all ranks to rank 0 + all_times_multi = [None] * world_size + dist.all_gather_object(all_times_multi, times_multi_local) + + intra_halo, inter_halo = _intra_inter_halo(comm_pattern, ranks_per_node) + stats_local = { + "rank": rank, + "n_local": n_local, + "n_halo": n_halo, + "n_local_edges": n_local_edges, + "intra_halo_size": intra_halo, + "inter_halo_size": inter_halo, + "c_intra_bytes": intra_halo * F * 4, + "c_inter_bytes": inter_halo * F * 4, + "send_total": int(comm_pattern.send_offset[-1].item()), + "recv_total": int(comm_pattern.recv_offset[-1].item()), + "trials_seconds": times_multi_local, + } + all_stats = [None] * world_size + dist.all_gather_object(all_stats, stats_local) + + # Free GPU memory before the single-GPU run + del x_local, edges_t, partitioning_t, layer_dist + if edge_attr_dist is not None: + del edge_attr_dist + torch.cuda.empty_cache() + + # ---- Rank 0: time single-GPU baseline (non-collective) ---- + times_single = None + single_oom = False + if rank == 0: + edges_s = _gen_graph( + 
args.graph, + num_vertices, + args.avg_degree, + args.sbm_inter_density, + graph_seed, + ) + try: + x_s, edge_t_s, layer_s, edge_attr_s = _build_single_gpu_tensors( + num_vertices, edges_s, F, args.model, device + ) + fn_single = _single_gpu_fn(x_s, edge_t_s, layer_s, edge_attr_s) + times_single = cuda_timed( + fn_single, warmup=args.warmup, trials=args.trials + ) + del x_s, edge_t_s, layer_s + if edge_attr_s is not None: + del edge_attr_s + torch.cuda.empty_cache() + except torch.cuda.OutOfMemoryError: + single_oom = True + print( + f"[crossover] rank 0 OOM: N={num_vertices:,} exceeds single-GPU memory" + ) + torch.cuda.empty_cache() + + # Sync all ranks after rank 0 finishes its single-GPU benchmark + dist.barrier() + + if rank == 0: + n_trials = len(times_multi_local) + # Wall time = max latency across all ranks per trial + times_multi_max = [ + max(all_times_multi[r][i] for r in range(world_size)) + for i in range(n_trials) + ] + + med_multi = sorted(times_multi_max)[n_trials // 2] + if times_single: + med_single = sorted(times_single)[len(times_single) // 2] + speedup = med_single / med_multi if med_multi > 0 else float("nan") + else: + med_single = float("nan") + speedup = float("nan") + + print( + f"[crossover] K={world_size:>3} N={num_vertices:>12,} " + f"single={1e3*med_single:.2f}ms " + f"multi={1e3*med_multi:.2f}ms " + f"speedup={speedup:.2f}x" + ) + + measurements.append( + { + "params": { + "num_vertices": num_vertices, + "num_global_edges": int(edges_np.shape[0]), + "avg_degree": args.avg_degree, + "feature_dim": F, + "model": args.model, + "graph": args.graph, + "partitioner": args.partitioner, + "world_size": world_size, + "ranks_per_node": ranks_per_node, + }, + "single_gpu_trials_seconds": times_single, + "single_gpu_oom": single_oom, + "multi_gpu_trials_seconds_rank0": times_multi_local, + "multi_gpu_trials_seconds_max": times_multi_max, + "per_rank_stats": all_stats, + } + ) + + return measurements, ranks_per_node + + +# 
--------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(): + args = parse_args() + graph_sizes = [int(s) for s in args.graph_sizes.split(",")] + F = args.feature_dim + + if args.no_dist: + measurements = run_no_dist(args, graph_sizes, F) + payload = { + "benchmark": "crossover", + "metadata": collect_metadata(), + "config": { + "graph": args.graph, + "graph_sizes": graph_sizes, + "avg_degree": args.avg_degree, + "sbm_inter_density": args.sbm_inter_density, + "feature_dim": F, + "model": args.model, + "partitioner": "none", + "world_size": 1, + "ranks_per_node": 1, + "warmup": args.warmup, + "trials": args.trials, + "seed": args.seed, + "mode": "single_gpu_only", + }, + "measurements": measurements, + } + write_result(args.output, payload) + return + + # ---- Distributed run ---- + rank, world_size, local_rank = setup_distributed() + seed_everything(args.seed + rank) + + measurements, ranks_per_node = run_distributed( + args, graph_sizes, F, rank, world_size, local_rank + ) + + if rank == 0: + payload = { + "benchmark": "crossover", + "metadata": collect_metadata(), + "config": { + "graph": args.graph, + "graph_sizes": graph_sizes, + "avg_degree": args.avg_degree, + "sbm_inter_density": args.sbm_inter_density, + "feature_dim": F, + "model": args.model, + "partitioner": args.partitioner, + "world_size": world_size, + "ranks_per_node": ranks_per_node, + "warmup": args.warmup, + "trials": args.trials, + "seed": args.seed, + "mode": "distributed", + }, + "measurements": measurements, + } + write_result(args.output, payload) + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() From 62f02487f3063ce476640398f97208d6b273c5f9 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Fri, 24 Apr 2026 21:04:15 -0700 Subject: [PATCH 7/7] Add running crossover benchmark and visualization --- DGraph/distributed/commInfo.py | 15 +- 
.../benchmarks/bench_crossover.py | 54 ++++-- .../benchmarks/nn_layer_common.py | 1 + .../visualization/plot_crossover.py | 183 ++++++++++++++++++ .../visualization/plot_flop_crossover.py | 140 ++++++++++++++ 5 files changed, 370 insertions(+), 23 deletions(-) create mode 100644 experiments/cost_model_benchmarks/visualization/plot_crossover.py create mode 100644 experiments/cost_model_benchmarks/visualization/plot_flop_crossover.py diff --git a/DGraph/distributed/commInfo.py b/DGraph/distributed/commInfo.py index 0d5bd7e..edf36a3 100644 --- a/DGraph/distributed/commInfo.py +++ b/DGraph/distributed/commInfo.py @@ -47,23 +47,22 @@ def compute_halo_vertices( """ Computes halo vertices. Supports both homogeneous and bipartite/heterogeneous relations. """ - # Fallback for homogeneous graphs if dst_partitioning is None: dst_partitioning = src_partitioning src_rank = src_partitioning[edge_list[:, 0]] dst_rank = dst_partitioning[edge_list[:, 1]] - # Cross-rank mask: source is local, destination is remote - cross_mask = (src_rank == rank) & (dst_rank != rank) + # FIX: Cross-rank mask for pull model: source is remote, destination is local + cross_mask = (src_rank != rank) & (dst_rank == rank) - # Return unique destination vertex IDs from those edges - return torch.unique(edge_list[cross_mask, 1]) + # FIX: Return unique SOURCE vertex IDs from those edges + return torch.unique(edge_list[cross_mask, 0]) def compute_local_edge_list( global_edge_list: torch.Tensor, # [E, 2] - partitioning: torch.Tensor, # [V] + partitioning: torch.Tensor, # [V] (Acts as dst_partitioning) local_vertices_global: torch.Tensor, # [num_local] halo_vertices_global: torch.Tensor, # [num_halo] rank: int, @@ -72,8 +71,8 @@ def compute_local_edge_list( num_halo = halo_vertices_global.size(0) num_global = partitioning.size(0) - # Filter edges owned by this rank - local_edge_mask = partitioning[global_edge_list[:, 0]] == rank + # FIX: Filter edges where the DESTINATION is owned by this rank (Index 1) + 
local_edge_mask = partitioning[global_edge_list[:, 1]] == rank local_edges_global = global_edge_list[local_edge_mask] # Build inverse map: global_id -> local_idx via scatter diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_crossover.py b/experiments/cost_model_benchmarks/benchmarks/bench_crossover.py index d6bc829..c5c2922 100644 --- a/experiments/cost_model_benchmarks/benchmarks/bench_crossover.py +++ b/experiments/cost_model_benchmarks/benchmarks/bench_crossover.py @@ -31,7 +31,7 @@ Usage — distributed (torchrun):: - torchrun --nnodes 2 --nproc_per_node 4 \\ + torchrun --nnodes 1 --nproc_per_node 4 \\ -m benchmarks.benchmark_crossover \\ --graph erdos_renyi \\ --graph-sizes 10000,100000,1000000,10000000 \\ @@ -55,7 +55,7 @@ import numpy as np import torch import torch.distributed as dist -import torch.nn as nn +import gc from benchmarks.common import ( collect_metadata, @@ -92,7 +92,7 @@ def parse_args(): p.add_argument( "--graph-sizes", type=str, - default="10000,100000,1000000,10000000", + default="100,200,400,800,1000,2000,4000,8000,10000,16000,32000,100000,200000,400000", help="Comma-separated list of num_vertices values to sweep", ) p.add_argument("--avg-degree", type=float, default=20.0) @@ -133,9 +133,7 @@ def _gen_graph(graph_type, num_vertices, avg_degree, sbm_inter_density, seed): return gen_sbm(num_vertices, avg_degree, sbm_inter_density, rng) -def _intra_inter_halo( - comm_pattern: CommunicationPattern, ranks_per_node: int -) -> tuple: +def _intra_inter_halo(comm_pattern: CommunicationPattern, ranks_per_node: int) -> tuple: """Return (intra_halo_vertices, inter_halo_vertices) from recv_offset.""" rank = comm_pattern.rank my_node = rank // ranks_per_node @@ -172,17 +170,21 @@ def _build_single_gpu_tensors(num_vertices, edges_np, F, model, device): def _single_gpu_fn(x, edge_t, layer, edge_attr): """Return a zero-argument closure for single-GPU forward+backward.""" if edge_attr is None: + def fn(): out = layer(x, edge_t) 
out.sum().backward() if x.grad is not None: x.grad.zero_() + else: + def fn(): out = layer(x, edge_t, edge_attr) out.sum().backward() if x.grad is not None: x.grad.zero_() + return fn @@ -195,14 +197,18 @@ def _multi_gpu_fn(x_local, comm_pattern, layer, halo_exchange, edge_attr, model) edge_index = comm_pattern.local_edge_list.T.contiguous() # [2, E_local] if model == "gcn": + def fn(): recv_buf = halo_exchange(x_local, comm_pattern) x_aug = torch.cat([x_local, recv_buf], dim=0) + out = layer(x_aug, edge_index) out.sum().backward() if x_local.grad is not None: x_local.grad.zero_() + else: + def fn(): recv_buf = halo_exchange(x_local, comm_pattern) x_aug = torch.cat([x_local, recv_buf], dim=0) @@ -210,6 +216,7 @@ def fn(): out.sum().backward() if x_local.grad is not None: x_local.grad.zero_() + return fn @@ -228,7 +235,11 @@ def run_no_dist(args, graph_sizes, F): # Unique seed per graph size so each topology is reproducible independently graph_seed = args.seed + num_vertices edges_np = _gen_graph( - args.graph, num_vertices, args.avg_degree, args.sbm_inter_density, graph_seed + args.graph, + num_vertices, + args.avg_degree, + args.sbm_inter_density, + graph_seed, ) num_edges = int(edges_np.shape[0]) @@ -292,21 +303,23 @@ def run_distributed(args, graph_sizes, F, rank, world_size, local_rank): device = torch.device(f"cuda:{local_rank}") comm = Communicator(backend="nccl") - halo_exchange = HaloExchange(comm=comm) ranks_per_node = int( - os.environ.get( - "LOCAL_WORLD_SIZE", os.environ.get("SLURM_NTASKS_PER_NODE", "4") - ) + os.environ.get("LOCAL_WORLD_SIZE", os.environ.get("SLURM_NTASKS_PER_NODE", "4")) ) measurements = [] for num_vertices in graph_sizes: + halo_exchange = HaloExchange(comm=comm) # All ranks generate *identical* graph topology (same seed per size) graph_seed = args.seed + num_vertices edges_np = _gen_graph( - args.graph, num_vertices, args.avg_degree, args.sbm_inter_density, graph_seed + args.graph, + num_vertices, + args.avg_degree, + 
args.sbm_inter_density, + graph_seed, ) # All ranks compute *identical* partition assignment @@ -319,7 +332,7 @@ def run_distributed(args, graph_sizes, F, rank, world_size, local_rank): assignment_np = partition_metis(num_vertices, world_size, edges_np) # Move to GPU for DGraph's build_communication_pattern (collective) - edges_t = torch.from_numpy(edges_np).long().to(device) # [E, 2] + edges_t = torch.from_numpy(edges_np).long().to(device) # [E, 2] partitioning_t = torch.from_numpy(assignment_np).long().to(device) # [V] # Collective: all ranks call this in sync (internally calls dist.all_gather) @@ -375,15 +388,26 @@ def run_distributed(args, graph_sizes, F, rank, world_size, local_rank): dist.all_gather_object(all_stats, stats_local) # Free GPU memory before the single-GPU run - del x_local, edges_t, partitioning_t, layer_dist + del ( + x_local, + edges_t, + partitioning_t, + layer_dist, + halo_exchange, + comm_pattern, + fn_multi, + ) if edge_attr_dist is not None: del edge_attr_dist + gc.collect() torch.cuda.empty_cache() # ---- Rank 0: time single-GPU baseline (non-collective) ---- times_single = None single_oom = False if rank == 0: + gc.collect() + torch.cuda.empty_cache() edges_s = _gen_graph( args.graph, num_vertices, @@ -399,7 +423,7 @@ def run_distributed(args, graph_sizes, F, rank, world_size, local_rank): times_single = cuda_timed( fn_single, warmup=args.warmup, trials=args.trials ) - del x_s, edge_t_s, layer_s + del x_s, edge_t_s, layer_s, fn_single if edge_attr_s is not None: del edge_attr_s torch.cuda.empty_cache() diff --git a/experiments/cost_model_benchmarks/benchmarks/nn_layer_common.py b/experiments/cost_model_benchmarks/benchmarks/nn_layer_common.py index 893d268..63a6af3 100644 --- a/experiments/cost_model_benchmarks/benchmarks/nn_layer_common.py +++ b/experiments/cost_model_benchmarks/benchmarks/nn_layer_common.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torch.distributed as dist # 
=========================================================================== # GNN layers (same as bench_compute.py for consistency) diff --git a/experiments/cost_model_benchmarks/visualization/plot_crossover.py b/experiments/cost_model_benchmarks/visualization/plot_crossover.py new file mode 100644 index 0000000..cfa02cb --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_crossover.py @@ -0,0 +1,183 @@ +import numpy as np +import matplotlib.pyplot as plt +import json +from pathlib import Path +import argparse + + +def parse_args(): + p = argparse.ArgumentParser(description="Plot Crossover Benchmark Results") + p.add_argument( + "--input", type=str, required=True, help="Path to benchmark JSON file" + ) + p.add_argument( + "--output", + type=str, + default="crossover_analysis.png", + help="Path to save the generated plot", + ) + return p.parse_args() + + +def plot_crossover_benchmark(payload, save_path="crossover_analysis.png"): + """ + Parses benchmark payload and plots Single GPU vs Multi-GPU execution times, + annotating the crossover point where distributed training becomes faster. 
+ """ + # Sort measurements strictly by graph size (num_vertices) + measurements = sorted( + payload["measurements"], key=lambda x: x["params"]["num_vertices"] + ) + + vertices = [] + single_gpu_times = [] + multi_gpu_times = [] + single_gpu_oom_points = [] + + world_size = payload["config"]["world_size"] + model_name = payload["config"]["model"] + partitioner = payload["config"]["partitioner"] + feature_dim = payload["config"]["feature_dim"] + + for m in measurements: + v = m["params"]["num_global_edges"] + vertices.append(v) + + # Multi-GPU: Use the max time across ranks as it represents the true synchronous bottleneck + multi_time = np.median(m["multi_gpu_trials_seconds_max"]) + multi_gpu_times.append(multi_time) + + # Single-GPU: Handle OOM scenarios safely + if m.get("single_gpu_oom", False) or not m["single_gpu_trials_seconds"]: + single_gpu_times.append(np.nan) + single_gpu_oom_points.append(v) + else: + single_gpu_times.append(np.median(m["single_gpu_trials_seconds"])) + + vertices = np.array(vertices) + single_gpu_times = np.array(single_gpu_times) + multi_gpu_times = np.array(multi_gpu_times) + + # Initialize Plot + plt.figure(figsize=(10, 6), dpi=150) + plt.grid(True, which="both", ls="-", alpha=0.2) + + # Plot Valid Single GPU points + valid_mask = ~np.isnan(single_gpu_times) + plt.plot( + vertices[valid_mask], + single_gpu_times[valid_mask], + marker="o", + linestyle="-", + linewidth=2, + color="#1f77b4", + label="Single GPU", + ) + + # Plot Multi GPU points + plt.plot( + vertices, + multi_gpu_times, + marker="s", + linestyle="-", + linewidth=2, + color="#ff7f0e", + label=f"Distributed ({world_size} GPUs)", + ) + + # Annotate OOM boundaries + for oom_v in single_gpu_oom_points: + plt.axvline(x=oom_v, color="#d62728", linestyle="--", alpha=0.6) + plt.text( + oom_v * 1.02, + plt.ylim()[1] * 0.9, + "Single GPU OOM", + color="#d62728", + verticalalignment="top", + ) + + # Calculate and Annotate Crossover Point + crossover_found = False + for i in range(1, 
len(vertices)): + if valid_mask[i - 1] and valid_mask[i]: + diff_prev = multi_gpu_times[i - 1] - single_gpu_times[i - 1] + diff_curr = multi_gpu_times[i] - single_gpu_times[i] + + # A sign change indicates the lines crossed + if diff_prev * diff_curr < 0: + x1, x2 = vertices[i - 1], vertices[i] + y1_s, y2_s = single_gpu_times[i - 1], single_gpu_times[i] + y1_m, y2_m = multi_gpu_times[i - 1], multi_gpu_times[i] + + # Linear interpolation for precise intersection coordinates + m_s = (y2_s - y1_s) / (x2 - x1) + m_m = (y2_m - y1_m) / (x2 - x1) + + if m_s != m_m: + x_cross = x1 + (y1_m - y1_s) / (m_s - m_m) + y_cross = y1_s + m_s * (x_cross - x1) + + plt.plot( + x_cross, + y_cross, + marker="*", + color="#2ca02c", + markersize=15, + zorder=5, + ) + plt.annotate( + f"Crossover:\n~{int(x_cross):,} edges", + xy=(x_cross, y_cross), + xytext=(-20, 40), + textcoords="offset points", + fontsize=10, + fontweight="bold", + color="#2ca02c", + arrowprops=dict( + arrowstyle="->", + connectionstyle="arc3,rad=.2", + color="#2ca02c", + ), + ) + crossover_found = True + break + + # Formatting + plt.title( + f"GNN Distributed Scaling Crossover\nModel: {model_name} | Partitioner: {partitioner} | Feature Dim: {feature_dim}", + fontsize=14, + pad=15, + ) + plt.xlabel("Graph Size (Number of Edges)", fontsize=12) + plt.ylabel("Execution Time (Seconds)", fontsize=12) + plt.xscale( + "log" + ) # Using log scale for x-axis as graph sizes usually scale exponentially + plt.yscale("linear") + + plt.legend(loc="upper left", framealpha=0.9) + plt.tight_layout() + + # Output + plt.savefig(save_path) + print(f"Visualization saved to {save_path}") + if crossover_found: + print(f"Crossover point detected at approximately {int(x_cross):,} vertices.") + else: + print("No crossover point detected in the provided dataset.") + + +# Example usage assuming `payload` is already loaded in your environment: +# plot_crossover_benchmark(payload) + + +def main(): + args = parse_args() + with open(args.input, "r") as 
f: + payload = json.load(f) + + plot_crossover_benchmark(payload, save_path=args.output) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/visualization/plot_flop_crossover.py b/experiments/cost_model_benchmarks/visualization/plot_flop_crossover.py new file mode 100644 index 0000000..b2ea789 --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_flop_crossover.py @@ -0,0 +1,140 @@ +import glob +import json +import re +import numpy as np +import matplotlib.pyplot as plt +from collections import defaultdict + + +def calculate_crossover(measurements): + """ + Calculates the exact crossover point (num_global_edges) where multi-GPU + execution becomes faster than single-GPU execution using linear interpolation. + """ + # Sort measurements strictly by graph size + measurements = sorted(measurements, key=lambda x: x["params"]["num_global_edges"]) + + vertices = [] + single_times = [] + multi_times = [] + + for m in measurements: + v = m["params"]["num_global_edges"] + + # Extract median times, ignore OOM or missing single GPU data + if m.get("single_gpu_oom", False) or not m.get("single_gpu_trials_seconds"): + continue + + s_time = np.median(m["single_gpu_trials_seconds"]) + m_time = np.median( + m["multi_gpu_trials_seconds_max"] + ) # Use max time for synchronous bottleneck + + vertices.append(v) + single_times.append(s_time) + multi_times.append(m_time) + + for i in range(1, len(vertices)): + diff_prev = multi_times[i - 1] - single_times[i - 1] + diff_curr = multi_times[i] - single_times[i] + + # Sign change indicates the lines crossed + if diff_prev * diff_curr < 0: + x1, x2 = vertices[i - 1], vertices[i] + y1_s, y2_s = single_times[i - 1], single_times[i] + y1_m, y2_m = multi_times[i - 1], multi_times[i] + + m_s = (y2_s - y1_s) / (x2 - x1) + m_m = (y2_m - y1_m) / (x2 - x1) + + if m_s != m_m: + x_cross = x1 + (y1_m - y1_s) / (m_s - m_m) + return x_cross + + return None # No crossover found in this dataset + + +def 
plot_crossover_dynamics( + file_pattern="results/crossover_*_world_F*.json", + save_path="results/crossover_vs_features.png", +): + """ + Parses benchmark files and plots the crossover point as a function of feature dimension, + with separate lines for each distributed world size. + """ + # Structure: data[world_size][feature_dim] = crossover_point + plot_data = defaultdict(dict) + + # Locate files + files = glob.glob(file_pattern) + if not files: + print(f"No files found matching pattern: {file_pattern}") + return + + # Regex to extract parameters from filename + pattern = re.compile(r"crossover_(\d+)_world_F(\d+)\.json") + + for filepath in files: + match = pattern.search(filepath) + if match: + world_size = int(match.group(1)) + feature_dim = int(match.group(2)) + + try: + with open(filepath, "r") as f: + payload = json.load(f) + + crossover_point = calculate_crossover(payload.get("measurements", [])) + if crossover_point is not None: + plot_data[world_size][feature_dim] = crossover_point + except Exception as e: + print(f"Error processing {filepath}: {e}") + + if not plot_data: + print("No valid crossover points could be calculated from the provided files.") + return + + # Initialize Plot + plt.figure(figsize=(10, 6), dpi=150) + plt.grid(True, which="both", ls="-", alpha=0.3) + + markers = ["o", "s", "^", "D", "v", "p"] + + # Plot lines per world size + for idx, (world_size, dim_data) in enumerate(sorted(plot_data.items())): + # Sort by feature dimension for sequential line plotting + sorted_dims = sorted(dim_data.keys()) + crossovers = [dim_data[d] for d in sorted_dims] + + plt.plot( + sorted_dims, + crossovers, + marker=markers[idx % len(markers)], + linestyle="-", + linewidth=2, + markersize=8, + label=f"{world_size} GPUs", + ) + + # Formatting + plt.title( + "GNN Communication Bottleneck:\nCrossover Threshold vs. 
Feature Dimension", + fontsize=14, + pad=15, + ) + plt.xlabel("Feature Dimension (F)", fontsize=12) + plt.ylabel("Crossover Point (Number of Edges)", fontsize=12) + + # Depending on the range of your feature dims, a log scale might be preferable for X + # plt.xscale("log", base=2) + plt.yscale("log") # Y-axis (vertices) usually scales exponentially + + plt.legend(title="World Size", loc="upper left", framealpha=0.9) + plt.tight_layout() + + plt.savefig(save_path) + print(f"Scaling dynamics visualization saved successfully to {save_path}") + + +if __name__ == "__main__": + plot_crossover_dynamics()