73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
import torch
|
|
import torchvision.models as models
|
|
import torchvision.transforms as transforms
|
|
from PIL import Image
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
|
|
# Step 1: Load a pretrained ResNet model.
# The boolean `pretrained=True` flag is deprecated since torchvision 0.13 in
# favour of the `weights=` enum. IMAGENET1K_V1 is the checkpoint that
# `pretrained=True` loaded, so the model's parameters are unchanged.
resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Switch to inference mode so BatchNorm layers use their stored running
# statistics; the captured activations then do not depend on the batch.
resnet.eval()
|
|
|
|
# Step 2: Prepare an input image
def preprocess_image(image_path, size=(224, 224)):
    """Load an image file and turn it into a normalized single-image batch.

    The image is converted to 3-channel RGB and resized to ``size`` (a
    ``(height, width)`` pair does not preserve aspect ratio). It is then
    converted to a float tensor and normalized with the standard ImageNet
    per-channel mean/std.

    Args:
        image_path: Path of the image file to load.
        size: Target size, passed unchanged to ``transforms.Resize``.
            Defaults to ``(224, 224)``, the value previously hard-coded.

    Returns:
        A tensor of shape ``(1, 3, height, width)``. The leading batch
        dimension comes from ``unsqueeze(0)``.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Use a context manager so the file handle is always released: Pillow only
    # closes the file automatically after load() for single-frame images.
    # convert() returns a new, fully loaded image, so it stays valid after the
    # file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')

    return transform(rgb_image).unsqueeze(0)
|
|
|
|
# Location of the image to analyse -- point this at your own file.
image_path = "/content/example.png"
input_tensor = preprocess_image(image_path=image_path)
|
|
|
|
# Step 3: Define a hook to capture activations.
# Maps each hooked module object to the most recent output it produced.
activations = dict()


def hook_fn(module, input, output):
    """Forward hook that records ``output`` under the key ``module``.

    PyTorch calls forward hooks positionally as ``hook(module, input, output)``
    after the module's forward runs. ``input`` is not used here. Returning
    ``None`` leaves the module's output unchanged. Any earlier entry for the
    same module is overwritten.
    """
    activations.update({module: output})
|
|
|
|
# Register the hook on one specific layer: the second convolution of the
# last block in layer4 (the final convolutional layer of ResNet-18).
layer = resnet.layer4[1].conv2
# Keep the returned handle so the hook can be unregistered afterwards.
hook = layer.register_forward_hook(hook=hook_fn)
|
|
|
|
# Step 4: Forward pass.
# Only the hook's side effect (filling `activations`) is needed, so the model's
# return value is discarded and autograd is disabled.
try:
    with torch.no_grad():
        resnet(input_tensor)
finally:
    # Unregister the hook even if the forward pass raises. Otherwise a stale
    # hook stays attached to `layer`, and re-running the registration step in
    # the same session (e.g. a notebook) would stack a second one on top.
    hook.remove()
|
|
|
|
# Step 5: Plot the activation distribution.
# The hook stored the layer output keyed by the layer module itself. squeeze()
# drops size-1 dimensions (the batch of one), giving [C, H, W] for one image.
activation_data = activations[layer].squeeze().cpu().numpy()
flattened_data = activation_data.flatten()

# Seaborn histogram with a KDE overlay. histplot draws on the current
# matplotlib Axes and returns it, so the labels are set on that Axes directly.
ax = sns.histplot(flattened_data, bins=50, kde=True)
ax.set_title("Neuron Activation Distribution")
ax.set_xlabel("Activation Value")
ax.set_ylabel("Frequency")
plt.show()
|
|
|
|
# sns.histplot(flattened_data, bins=50, kde=True, stat='density')
|
|
# plt.title("Neuron Activation Distribution")
|
|
# plt.xlabel("Activation Value")
|
|
# plt.ylabel("Density")
|
|
# plt.show()
|
|
|
|
|
|
# from sklearn.preprocessing import StandardScaler
|
|
|
|
# # Normalize activation values to z-score
|
|
# scaler = StandardScaler()
|
|
# normalized_data = scaler.fit_transform(flattened_data.reshape(-1, 1)).flatten()
|
|
|
|
# # Plot the normalized activation distribution
|
|
# sns.histplot(normalized_data, bins=50, kde=True, stat='density')
|
|
# plt.title("Neuron Activation Distribution (Normalized Features)")
|
|
# plt.xlabel("Normalized Activation Value (Z-Score)")
|
|
# plt.ylabel("Density")
|
|
# plt.show()
|
|
|