This commit is contained in:
Justin Chan 2023-10-03 09:48:39 -07:00
parent 40f5484c1c
commit da5bf45da5
1 changed files with 2 additions and 2 deletions

View File

@ -84,7 +84,7 @@ class AttentionPool2d(nn.Module):
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
need_weights=True
)
return x[0]
@ -180,7 +180,7 @@ class ResidualAttentionBlock(nn.Module):
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)[1]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))