fix inefficient attention computation

This commit is contained in:
Penn 2022-07-21 11:35:03 -07:00
parent 4d120f3ec3
commit 8c0b98adfc
1 changed file with 4 additions and 5 deletions

View File

@@ -67,10 +67,10 @@ class AttentionPool2d(nn.Module):
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x, key=x, value=x,
query=x[:1], key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
@@ -88,8 +88,7 @@ class AttentionPool2d(nn.Module):
training=self.training,
need_weights=False
)
return x[0]
return x.squeeze(0)
class ModifiedResNet(nn.Module):