1.torch.mul / mm / matlmul / bmm 等区别
1、torch.mul(a, b)是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵;
2、torch.mm(a, b)是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵。
3.torch.bmm()强制规定维度和大小相同,torch.bmm()是tensor中的一个相乘操作,类似于矩阵中的A*B。
4.torch.matmul()没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作
当进行操作的两个tensor都是3D时,两者等同。
当输入是都是二维时,就是普通的矩阵乘法,和tensor.mm函数用法相同。
当输入有多维时,把多出的一维作为batch提出来,其他部分做矩阵乘法。
。
下面看一个两个都是3维的例子。
。
将b的第0维1broadcast成2提出来,后两维做矩阵乘法即可。
再看一个复杂一点的,是官网的例子。
。
2.random.shuffle函数
在设置好batch后,需要随机抽取,那么可以对数据总量进行获取,并生成一个list,通过random.shuffle(list) 打乱序号,以进行随机抽取
random.shuffle(list ): 功能就是打乱list中元素的顺序
1 | def data_iter(batch_size,features,labels): |