Fixed documentation for Zero-Shot Prediction

norm must be done with dim=1 not dim=-1
This commit is contained in:
Peter Kuhar 2023-01-26 16:58:42 -08:00 committed by GitHub
parent 3702849800
commit 6c0d4766aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -112,8 +112,8 @@ with torch.no_grad():
text_features = model.encode_text(text_inputs)
# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
image_features /= image_features.norm(dim=1, keepdim=True)
text_features /= text_features.norm(dim=1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)