赞
踩
import torch outputs = torch.tensor([[0.1, 0.2], [0.05, 0.4]]) print(outputs.argmax(1)) # 里面的1是一个方向,代表矩阵横向 # 输出tensor([1, 1]) 说明0.2在第一行最大,0.4在第二行最大 preds = outputs.argmax(1) targets = torch.tensor([0, 1]) print(preds == targets) # 因为tensor([1, 1])和tensor([0, 1]) # 所以输出tensor([False, True]) print((preds == targets).sum()) # 输出tensor(1),这是对应位置相等的个数 print(outputs.argmax(0)) # 里面的1是一个方向,代表矩阵的列 # 输出tensor([0, 1]) 说明0.1在第一列最大,0.4在第二列最大
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。