diff --git a/graph_net/torch/extractor.py b/graph_net/torch/extractor.py index 522701d8e..568ad995a 100644 --- a/graph_net/torch/extractor.py +++ b/graph_net/torch/extractor.py @@ -2,6 +2,7 @@ import torch import json import shutil +import glob from graph_net.torch import utils from graph_net.torch.fx_graph_serialize_util import serialize_graph_module_to_str @@ -72,12 +73,28 @@ def move_files(self, source_dir, target_dir): target_path = os.path.join(target_dir, item) shutil.move(source_path, target_path) + def _cleanup_stale_data(self, model_path): + for stale_dir in glob.glob(os.path.join(model_path, "subgraph_*")): + shutil.rmtree(stale_dir) + for stale_file_name in ( + "model.py", + "graph_net.json", + "input_meta.py", + "weight_meta.py", + "input_tensor_constraints.py", + "graph_hash.txt", + ): + stale_file = os.path.join(model_path, stale_file_name) + if os.path.isfile(stale_file): + os.remove(stale_file) + def __call__(self, gm: torch.fx.GraphModule, sample_inputs): # 1. Get model path model_path = os.path.join(self.workspace_path, self.name) os.makedirs(model_path, exist_ok=True) if self.subgraph_counter == 0: + self._cleanup_stale_data(model_path) subgraph_path = model_path else: if self.subgraph_counter == 1: @@ -124,17 +141,30 @@ def try_rename_placeholder(node): gm.graph.erase_node(node) assert input_idx == len(sample_inputs) + + # 3. Serialize graph + base_code = serialize_graph_module_to_str(gm) + if self.mut_graph_codes is not None: assert isinstance(self.mut_graph_codes, list) - self.mut_graph_codes.append(serialize_graph_module_to_str(gm)) - # 3. Generate and save model code - base_code = serialize_graph_module_to_str(gm) - # gm.graph.print_tabular() + self.mut_graph_codes.append(base_code) + + # 4. Save tensor metadata + converted = utils.convert_state_and_inputs(params, []) + utils.save_converted_to_text(converted, file_path=subgraph_path) + utils.save_constraints_text( + converted, + file_path=os.path.join(subgraph_path, "input_tensor_constraints.py"), + ) + + # 5. Save model code write_code = utils.apply_templates(base_code) with open(os.path.join(subgraph_path, "model.py"), "w") as fp: fp.write(write_code) - # 4. Save metadata + # 6. Save metadata LAST — graph_net.json serves as the + # completion marker: if it exists, all other files are guaranteed + # to be fully written. metadata = { "framework": "torch", "num_devices_required": 1, @@ -145,15 +175,6 @@ def try_rename_placeholder(node): with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f: json.dump(metadata, f, indent=4) - # 5. Save tensor metadata - # Adapt to different input structures (e.g., single tensor vs. dict/tuple of tensors) - converted = utils.convert_state_and_inputs(params, []) - utils.save_converted_to_text(converted, file_path=subgraph_path) - utils.save_constraints_text( - converted, - file_path=os.path.join(subgraph_path, "input_tensor_constraints.py"), - ) - print( f"Graph and tensors for '{self.name}' extracted successfully to: {model_path}" )