diff --git a/graph_net/torch/extractor.py b/graph_net/torch/extractor.py index 568ad995a..949b71321 100644 --- a/graph_net/torch/extractor.py +++ b/graph_net/torch/extractor.py @@ -8,7 +8,10 @@ torch._dynamo.config.capture_scalar_outputs = True torch._dynamo.config.capture_dynamic_output_shape_ops = True -torch._dynamo.config.capture_sparse_compute = True +try: + torch._dynamo.config.capture_sparse_compute = True +except AttributeError: + pass torch._dynamo.config.raise_on_ctx_manager_usage = False torch._dynamo.config.allow_rnn = True diff --git a/graph_net/torch/sample_pass/backward_graph_extractor.py b/graph_net/torch/sample_pass/backward_graph_extractor.py index e50d3b7e9..25ca3b81f 100644 --- a/graph_net/torch/sample_pass/backward_graph_extractor.py +++ b/graph_net/torch/sample_pass/backward_graph_extractor.py @@ -27,7 +27,7 @@ def __call__(self): module, forward_inputs = get_torch_module_and_inputs( self.model_path, use_dummy_inputs=False, device=self.device ) - module.train() + module.eval() eval_forward_dir = os.path.join( self.output_dir, "eval_forward", self.rel_model_path @@ -35,6 +35,10 @@ def __call__(self): if not os.path.exists(eval_forward_dir): shutil.copytree(self.model_path, eval_forward_dir) + forward_inputs = [ + inp.detach().clone() if isinstance(inp, torch.Tensor) else inp + for inp in forward_inputs + ] forward_inputs = self.set_requires_grad_for_forward_inputs( self.model_path, module, forward_inputs )