下面改动这个代码,方便理解和实验
1 |
|
分析:
1. x.shape
1 | x |
2. x和weight点积
1 | ## 中间过程忽略 |
3. 怎么来的呢
1 |
|
4. y.permute((0, 1, 3, 2))
1 | y.permute((0, 1, 3, 2)) |
5. torch.matmul(torch.matmul(x, weight), y.permute((0, 1, 3, 2)))
1 | torch.matmul(torch.matmul(x, weight), y.permute((0, 1, 3, 2))) |
final. 扩展
1 | # 扩展,如果out为2的话 |