Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions traincheck/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"torch.fx",
"torch._dynamo",
"torch._sources", # FIXME: cannot handle this module, instrumenting it will lead to exceptions: TypeError: module, class, method, function, traceback, frame, or code object was expected, got builtin_function_or_method
"torchtitan.trainer.Trainer.train_step",
# "torch.autocast",
# "torch.amp",
# "torch.matmul",
Expand Down
2 changes: 1 addition & 1 deletion traincheck/instrumentor/dumper.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def dump_tensor(value):
param_list = tensor_stats(value)
elif tensor_dump_format["dump_tensor_hash"]:
if value.is_cuda:
param_list = tensor_hash(value, with_parallel=True, with_cuda=True)
param_list = tensor_hash(value, with_parallel=True, with_cuda=False)
else:
# TODO: support quick hashing methods for MPS tensors
param_list = tensor_hash(value, with_parallel=True, with_cuda=False)
Expand Down
2 changes: 2 additions & 0 deletions traincheck/instrumentor/proxy_wrapper/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def _reduce_last_axis(x: Tensor) -> Tensor:
def tensor_hash(x: Tensor, with_parallel: bool = True, with_cuda: bool = True) -> int:
if hasattr(x, "_traincheck_tensor_hash"):
return x._traincheck_tensor_hash
if hasattr(x, "to_local"):
x = x.to_local()
if with_parallel:
if x.dtype in [
torch.float32,
Expand Down
3 changes: 2 additions & 1 deletion traincheck/instrumentor/proxy_wrapper/subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ def subclass_setattr_hook(self, name, value):

@functools.wraps(orig_setattr)
def wrapped_setattr(self, name, value):
hook = getattr(self, SUBCLASS_HOOK_KEY, None)
#hook = getattr(self, SUBCLASS_HOOK_KEY, None)
hook = self.__dict__.get(SUBCLASS_HOOK_KEY)
if hook is not None:
# If hook returns True, skip the original setattr; otherwise continue.
hook(self, name, value)
Expand Down
38 changes: 34 additions & 4 deletions traincheck/instrumentor/tracer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import functools
import importlib
from importlib.util import find_spec
import inspect
import os
import pathlib
import threading
import time
import traceback
Expand Down Expand Up @@ -522,11 +524,31 @@ def get_module_path_from_file_path(file_path: str, root_module: str) -> str | No
or f"/{root_module}/" not in file_path
):
return None
# get the path of the module from the file path
path_after_root_module = file_path.split(f"/{root_module}/")[1].split(".py")[0]
module_path = f"{root_module}.{path_after_root_module}".replace("/", ".")
return module_path

# get the location of the root module
spec = find_spec(root_module)
if spec is None or spec.origin is None:
# raise error
raise ImportError(f"Cannot locate root module {root_module!r}")

# get the path to the root module
root_module_path = pathlib.PurePath(spec.origin).parent

# parse file_path into a Path object (requires file I/O)
file_path_obj = pathlib.Path(file_path).resolve()

try:
# remove the root module from file_path_obj
relative_path = file_path_obj.relative_to(root_module_path)
except ValueError:
# file_path is not under root module
raise ImportError(f"File path is not under root module {root_module!r}")

# strip off .py
module_name = str(relative_path.with_suffix(''))

# replace / with .
return f"{root_module}.{module_name.replace('/', '.')}"

class Instrumentor:
def __init__(
Expand Down Expand Up @@ -634,6 +656,14 @@ def instrument(self) -> int:

global IS_INSTRUMENTING
IS_INSTRUMENTING = True

try:
import torch.utils._device as _tc_torch_device

_tc_torch_device._device_constructors()
except Exception:
pass

visited_file_paths: set[str] = set()

first_pass_instrumented_count = 0
Expand Down
4 changes: 3 additions & 1 deletion traincheck/trace/trace_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,9 @@ def get_column_dtype(self, column_name: str) -> type:
return self.column_dtypes_cached[column_name]

filtered_values = self.events[column_name].dropna()
filtered_values = filtered_values[filtered_values != MD_NONE()]
#filtered_values = filtered_values[filtered_values != MD_NONE()]
if filtered_values.dtype == object:
filtered_values = filtered_values[filtered_values != MD_NONE()]

if filtered_values.empty:
self.column_dtypes_cached[column_name] = MD_NONE
Expand Down