diff --git a/traincheck/config/config.py b/traincheck/config/config.py index ddbc4a2a..4aee9942 100644 --- a/traincheck/config/config.py +++ b/traincheck/config/config.py @@ -21,6 +21,7 @@ "torch.fx", "torch._dynamo", "torch._sources", # FIXME: cannot handle this module, instrumenting it will lead to exceptions: TypeError: module, class, method, function, traceback, frame, or code object was expected, got builtin_function_or_method + "torchtitan.trainer.Trainer.train_step", # "torch.autocast", # "torch.amp", # "torch.matmul", diff --git a/traincheck/instrumentor/dumper.py b/traincheck/instrumentor/dumper.py index 27c458f8..b808e986 100644 --- a/traincheck/instrumentor/dumper.py +++ b/traincheck/instrumentor/dumper.py @@ -296,7 +296,7 @@ def dump_tensor(value): param_list = tensor_stats(value) elif tensor_dump_format["dump_tensor_hash"]: if value.is_cuda: - param_list = tensor_hash(value, with_parallel=True, with_cuda=True) + param_list = tensor_hash(value, with_parallel=True, with_cuda=False) else: # TODO: support quick hashing methods for MPS tensors param_list = tensor_hash(value, with_parallel=True, with_cuda=False) diff --git a/traincheck/instrumentor/proxy_wrapper/hash.py b/traincheck/instrumentor/proxy_wrapper/hash.py index 1c1b3088..d0dae64f 100644 --- a/traincheck/instrumentor/proxy_wrapper/hash.py +++ b/traincheck/instrumentor/proxy_wrapper/hash.py @@ -95,6 +95,8 @@ def _reduce_last_axis(x: Tensor) -> Tensor: def tensor_hash(x: Tensor, with_parallel: bool = True, with_cuda: bool = True) -> int: if hasattr(x, "_traincheck_tensor_hash"): return x._traincheck_tensor_hash + if hasattr(x, "to_local"): + x = x.to_local() if with_parallel: if x.dtype in [ torch.float32, diff --git a/traincheck/instrumentor/proxy_wrapper/subclass.py b/traincheck/instrumentor/proxy_wrapper/subclass.py index ac7ae393..1e059fb7 100644 --- a/traincheck/instrumentor/proxy_wrapper/subclass.py +++ b/traincheck/instrumentor/proxy_wrapper/subclass.py @@ -279,7 +279,8 @@ def subclass_setattr_hook(self, name, value): @functools.wraps(orig_setattr) def wrapped_setattr(self, name, value): - hook = getattr(self, SUBCLASS_HOOK_KEY, None) + #hook = getattr(self, SUBCLASS_HOOK_KEY, None) + hook = self.__dict__.get(SUBCLASS_HOOK_KEY) if hook is not None: # If hook returns True, skip the original setattr; otherwise continue. hook(self, name, value) diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py index 8a6ac1e1..5ba9d984 100644 --- a/traincheck/instrumentor/tracer.py +++ b/traincheck/instrumentor/tracer.py @@ -1,7 +1,9 @@ import functools import importlib +from importlib.util import find_spec import inspect import os +import pathlib import threading import time import traceback @@ -522,11 +524,31 @@ def get_module_path_from_file_path(file_path: str, root_module: str) -> str | No or f"/{root_module}/" not in file_path ): return None - # get the path of the module from the file path - path_after_root_module = file_path.split(f"/{root_module}/")[1].split(".py")[0] - module_path = f"{root_module}.{path_after_root_module}".replace("/", ".") - return module_path + # get the location of the root module + spec = find_spec(root_module) + if spec is None or spec.origin is None: + # raise error + raise ImportError(f"Cannot locate root module {root_module!r}") + + # get the path to the root module + root_module_path = pathlib.PurePath(spec.origin).parent + + # parse file_path into a Path object (requires file I/O) + file_path_obj = pathlib.Path(file_path).resolve() + + try: + # remove the root module from file_path_obj + relative_path = file_path_obj.relative_to(root_module_path) + except ValueError: + # file_path is not under root module + raise ImportError(f"File path is not under root module {root_module!r}") + + # strip off .py + module_name = str(relative_path.with_suffix('')) + + # replace / with . + return f"{root_module}.{module_name.replace('/', '.')}" class Instrumentor: def __init__( @@ -634,6 +656,14 @@ def instrument(self) -> int: global IS_INSTRUMENTING IS_INSTRUMENTING = True + + try: + import torch.utils._device as _tc_torch_device + + _tc_torch_device._device_constructors() + except Exception: + pass + visited_file_paths: set[str] = set() first_pass_instrumented_count = 0 diff --git a/traincheck/trace/trace_pandas.py b/traincheck/trace/trace_pandas.py index 6263bea6..8bb89ca7 100644 --- a/traincheck/trace/trace_pandas.py +++ b/traincheck/trace/trace_pandas.py @@ -558,7 +558,9 @@ def get_column_dtype(self, column_name: str) -> type: return self.column_dtypes_cached[column_name] filtered_values = self.events[column_name].dropna() - filtered_values = filtered_values[filtered_values != MD_NONE()] + #filtered_values = filtered_values[filtered_values != MD_NONE()] + if filtered_values.dtype == object: + filtered_values = filtered_values[filtered_values != MD_NONE()] if filtered_values.empty: self.column_dtypes_cached[column_name] = MD_NONE