parent 40f5484c1c
commit da5bf45da5
@@ -84,7 +84,7 @@ class AttentionPool2d(nn.Module):
             out_proj_bias=self.c_proj.bias,
             use_separate_proj_weight=True,
             training=self.training,
-            need_weights=False
+            need_weights=True
         )
 
         return x[0]
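For context (not part of the commit): with need_weights=True, torch.nn.functional.multi_head_attention_forward returns a (attn_output, attn_output_weights) tuple instead of (attn_output, None), so x[0] is still the pooled output while x[1] now carries the head-averaged weights. A minimal self-contained sketch, with illustrative sizes and random projection weights standing in for self.q_proj / k_proj / v_proj / c_proj (not the repository's code):

import torch
import torch.nn.functional as F

E, H, L, N = 64, 8, 50, 2          # embed dim, heads, sequence length, batch
q = k = v = torch.randn(L, N, E)   # (seq_len, batch, embed_dim)
wq, wk, wv = (torch.randn(E, E) for _ in range(3))

x = F.multi_head_attention_forward(
    q, k, v, E, H,
    in_proj_weight=None,
    in_proj_bias=torch.zeros(3 * E),
    bias_k=None, bias_v=None,
    add_zero_attn=False,
    dropout_p=0.0,
    out_proj_weight=torch.randn(E, E),
    out_proj_bias=torch.zeros(E),
    use_separate_proj_weight=True,
    q_proj_weight=wq, k_proj_weight=wk, v_proj_weight=wv,
    need_weights=True,              # was False before this commit
)
# x is a 2-tuple either way; with need_weights=False, x[1] would be None.
print(x[0].shape)  # torch.Size([50, 2, 64]) attended values -> `return x[0]`
print(x[1].shape)  # torch.Size([2, 50, 50]) head-averaged attention weights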
@@ -180,7 +180,7 @@ class ResidualAttentionBlock(nn.Module):
 
     def attention(self, x: torch.Tensor):
         self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
-        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+        return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)[1]
 
     def forward(self, x: torch.Tensor):
         x = x + self.attention(self.ln_1(x))
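For context (not part of the commit): nn.MultiheadAttention likewise returns an (attn_output, attn_output_weights) pair, so switching the index from [0] to [1] makes attention() hand back the attention weights (averaged over heads by default) instead of the attended values. A small self-contained sketch with illustrative sizes (not the repository's code):

import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=64, num_heads=8)
x = torch.randn(50, 2, 64)  # (seq_len, batch, embed_dim)

out, weights = attn(x, x, x, need_weights=False)
print(weights)              # None: weights were not computed

out, weights = attn(x, x, x, need_weights=True)
print(out.shape)            # torch.Size([50, 2, 64]) attended values, index [0]
print(weights.shape)        # torch.Size([2, 50, 50]) averaged weights, index [1]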