添加 neuron activation distribution
This commit is contained in:
parent
06ae8f34a0
commit
50ac9ee71e
|
|
@ -0,0 +1,51 @@
|
||||||
|
import torch
|
||||||
|
import torchvision.models as models
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from PIL import Image
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
|
# Step 1: Load a pretrained ResNet model
|
||||||
|
resnet = models.resnet18(pretrained=True)
|
||||||
|
resnet.eval()
|
||||||
|
|
||||||
|
# Step 2: Prepare an input image
|
||||||
|
def preprocess_image(image_path):
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
])
|
||||||
|
image = Image.open(image_path).convert('RGB')
|
||||||
|
return transform(image).unsqueeze(0)
|
||||||
|
|
||||||
|
image_path = "/content/example.png" # Replace with your image path
|
||||||
|
input_tensor = preprocess_image(image_path)
|
||||||
|
|
||||||
|
# Step 3: Define a hook to capture activations
|
||||||
|
activations = {}
|
||||||
|
|
||||||
|
def hook_fn(module, input, output):
|
||||||
|
activations[module] = output
|
||||||
|
|
||||||
|
# Register hooks for a specific layer
|
||||||
|
layer = resnet.layer4[1].conv2 # Example: last convolutional layer
|
||||||
|
hook = layer.register_forward_hook(hook_fn)
|
||||||
|
|
||||||
|
# Step 4: Forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
resnet(input_tensor)
|
||||||
|
|
||||||
|
# Unregister the hook
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
# Step 5: Plot the activation distribution
|
||||||
|
activation_data = activations[layer].squeeze().cpu().numpy() # Shape: [C, H, W]
|
||||||
|
flattened_data = activation_data.flatten()
|
||||||
|
|
||||||
|
# Use seaborn for better visuals
|
||||||
|
sns.histplot(flattened_data, bins=50, kde=True)
|
||||||
|
plt.title("Neuron Activation Distribution")
|
||||||
|
plt.xlabel("Activation Value")
|
||||||
|
plt.ylabel("Frequency")
|
||||||
|
plt.show()
|
||||||
Loading…
Reference in New Issue