Local variable serialization for workload training errors #4084

Open
landenai opened this issue Feb 20, 2025 · 0 comments

landenai commented Feb 20, 2025

Problem Statement

Errors raised in training workloads (e.g., jobs running on GPUs) often leave behind local variables that either cannot be serialized or have repr/str methods that call out to a device. Attaching these local variables to a Sentry event can crash the process with a segfault. I can turn off local variables for my error events entirely, but that is an all-or-nothing switch, and it throws away valuable debugging context.
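A segfault is a process crash, not a Python exception, so wrapping repr() in try/except cannot protect the SDK. A per-variable policy therefore has to decide from the object's type alone, before repr() ever runs. A minimal sketch, assuming that "the type lives in a torch/jax/tensorflow module" is a good-enough signal; DEVICE_BACKED_MODULES, is_device_backed and summarize_locals are hypothetical names, not part of sentry_sdk:

```python
from typing import Any, Dict

# Hypothetical list of top-level packages whose objects may touch a device
# when repr()/str() is called on them; adjust for your own stack.
DEVICE_BACKED_MODULES = ("torch", "jax", "jaxlib", "tensorflow")


def is_device_backed(value: Any) -> bool:
    """Classify by the type's module name only; never calls repr() or str()."""
    module = type(value).__module__ or ""
    return module.split(".")[0] in DEVICE_BACKED_MODULES


def summarize_locals(frame_locals: Dict[str, Any]) -> Dict[str, str]:
    """repr() ordinary locals; replace device-backed ones with a placeholder.

    A segfault cannot be caught with try/except, so the only safe option is
    to avoid calling repr() on a risky object in the first place.
    """
    summary: Dict[str, str] = {}
    for name, value in frame_locals.items():
        if is_device_backed(value):
            cls = type(value)
            summary[name] = f"<{cls.__module__}.{cls.__qualname__} omitted>"
        else:
            summary[name] = repr(value)
    return summary


# Usage with a stand-in class whose repr() would "touch the device"
class FakeTensor:
    def __repr__(self) -> str:
        raise RuntimeError("repr() would touch the device")


FakeTensor.__module__ = "torch"  # pretend the class comes from PyTorch

print(summarize_locals({"step": 3, "batch": FakeTensor()}))
# {'step': '3', 'batch': '<torch.FakeTensor omitted>'}
```

The module-name check is deliberately coarse: it keeps the decision cheap and never imports a framework, at the cost of also hiding harmless objects from those packages.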

Solution Brainstorm

Add framework-specific logic to the SDK, together with an accompanying flag that extracts only tensor metadata (type, shape, dtype, device) and never the tensor contents. Here's an example:
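The accompanying flag could be a three-way setting rather than an on/off switch. The sketch below is a hypothetical design (LocalVarsMode and serialize_locals do not exist in sentry_sdk). The tensor check and the metadata extractor are passed in as callables, so the extract_tensor_metadata function from this issue could be supplied as extract_metadata:

```python
from enum import Enum
from typing import Any, Callable, Dict


class LocalVarsMode(Enum):
    """Hypothetical SDK option; not an existing sentry_sdk setting."""

    ALL = "all"    # roughly today's behaviour with local variables enabled
    NONE = "none"  # today's opt-out: send no local variables at all
    TENSOR_METADATA = "tensor_metadata"  # proposed: metadata only for tensors


def serialize_locals(
    frame_locals: Dict[str, Any],
    mode: LocalVarsMode,
    is_tensor: Callable[[Any], bool],
    extract_metadata: Callable[[Any], Dict[str, Any]],
) -> Dict[str, Any]:
    """Serialize one frame's locals according to the hypothetical mode."""
    if mode is LocalVarsMode.NONE:
        return {}
    result: Dict[str, Any] = {}
    for name, value in frame_locals.items():
        if mode is LocalVarsMode.TENSOR_METADATA and is_tensor(value):
            # In this mode a tensor is never repr()'d; only metadata is kept
            result[name] = extract_metadata(value)
        else:
            result[name] = repr(value)
    return result
```

Keeping ALL as the default would preserve existing behaviour, so the metadata-only mode would be strictly opt-in.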

import sys


def extract_tensor_metadata(obj):
    """Safely extract metadata from tensor objects across frameworks.

    Only cheap attributes (shape, dtype, device) are read. Tensor contents,
    and repr()/str() of the tensor itself, are never touched.
    """
    metadata = {'type': type(obj).__name__}

    # Look frameworks up in sys.modules instead of importing them, so the
    # error handler never imports a framework the application didn't load
    torch = sys.modules.get('torch')
    jnp = sys.modules.get('jax.numpy')
    tf = sys.modules.get('tensorflow')

    if torch is not None and isinstance(obj, torch.Tensor):
        # PyTorch tensors
        try:
            metadata['shape'] = tuple(obj.shape)
            metadata['dtype'] = str(obj.dtype)
            metadata['device'] = str(obj.device)
            metadata['requires_grad'] = bool(obj.requires_grad)
            # Don't access .data or the tensor contents
        except Exception as e:
            metadata['extraction_error'] = str(e)

    elif jnp is not None and isinstance(obj, jnp.ndarray):
        # JAX arrays
        try:
            metadata['shape'] = tuple(obj.shape)
            metadata['dtype'] = str(obj.dtype)
            # Guard device lookup separately so shape/dtype survive a failure
            try:
                if hasattr(obj, 'devices'):
                    # jax.Array API: a set of devices (sharded arrays may span several)
                    metadata['device'] = ', '.join(sorted(str(d) for d in obj.devices()))
                elif hasattr(obj, 'device_buffer'):
                    # Older JAX releases exposed the device via device_buffer
                    metadata['device'] = str(obj.device_buffer.device())
            except Exception:
                metadata['device'] = 'unknown_jax_device'
        except Exception as e:
            metadata['extraction_error'] = str(e)

    elif tf is not None and isinstance(obj, tf.Tensor):
        # TensorFlow tensors
        try:
            metadata['shape'] = tuple(obj.shape)
            metadata['dtype'] = str(obj.dtype)
            # Device lookup varies by tensor type, so guard it separately
            try:
                if hasattr(obj, 'device'):
                    metadata['device'] = str(obj.device)
                elif hasattr(obj, '_device'):
                    metadata['device'] = str(obj._device)
            except Exception:
                metadata['device'] = 'unknown_tf_device'
        except Exception as e:
            metadata['extraction_error'] = str(e)

    # Fallback for other array-like objects: read common attributes if present
    for attr in ('shape', 'dtype'):
        if attr not in metadata:
            try:
                val = getattr(obj, attr)
            except Exception:
                continue  # attribute missing, or its getter failed
            metadata[attr] = str(val) if val is not None else None

    return metadata
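To show where an extractor like this would plug in, here is a sketch of walking an exception's traceback and summarizing each frame's locals with a caller-supplied function. collect_frame_locals is a hypothetical helper, not the SDK's internal stack-walking code. Passing a summarizer that only reads type names means no repr() runs on any local:

```python
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional


def collect_frame_locals(
    tb: Optional[TracebackType],
    summarize: Callable[[Any], Any],
) -> List[Dict[str, Any]]:
    """Summarize every frame's locals, from the catching frame to the raising one."""
    frames: List[Dict[str, Any]] = []
    while tb is not None:
        frame = tb.tb_frame
        frames.append({
            "function": frame.f_code.co_name,
            "lineno": tb.tb_lineno,
            "vars": {name: summarize(value) for name, value in frame.f_locals.items()},
        })
        tb = tb.tb_next
    return frames


# Usage: summarize by type name only, so no repr() runs on any local
def train_step(batch_size):
    loss = None
    raise ValueError("loss diverged")


try:
    train_step(32)
except ValueError as exc:
    frames = collect_frame_locals(exc.__traceback__, lambda v: type(v).__name__)
    print(frames[-1]["function"], frames[-1]["vars"])
    # train_step {'batch_size': 'int', 'loss': 'NoneType'}
```

In practice the summarizer would dispatch: metadata for tensors (e.g. via extract_tensor_metadata) and a truncated repr for everything else.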
landenai changed the title from "Smarter local variable serialization for workload training errors" to "Local variable serialization for workload training errors" on Feb 20, 2025
Labels: None yet
Projects: None yet
Development: No branches or pull requests
Participants: 1