当前位置:主页 > 查看内容

Pytorch笔记3-2:“张量操作(补充)”

发布时间:2021-04-27 00:00| 位朋友查看

简介:文章目录 前言 一、合并与分割 1.张量合并 2.张量分割 二、数学运算 1.张量的四则运算 2.张量的幂指运算 3.张量的近似运算 4.裁剪 三、合并与分割 1.范数 2.序号索引 1.未指定索引轴时 2.指定索引轴时 3.保持维度 3.保留前K个值TOP-K 4.逻辑关系 ! 四、where……


前言

补充一些常用的张量操作。


一、合并与分割


1.张量合并

torch.cat([a,b],dim = c) 用于合并张量,但是要保证张量的数据维度可以合并不会出错(即在要合并的轴 数据维度可以不一样,但是其他的轴数据要保持维度相同)。
a,b :指要合并的数据,c 表示要合并的所在轴。

代码如下:

import torch

a = torch.rand(4,3,28,28)
b = torch.rand(6,3,28,28)
#除0轴数据维度可以不一样(46),123轴数据维度都相同故可以合并
c = torch.cat([a,b],dim = 0)
#除1轴数据维度可以不同也可以相同(33);1轴据维度不同(46);
# 23轴数据维度都相同故不可以合并,打印出错
# d = torch.cat([a,b],dim = 1)
print(c.shape)
# print(d.shape)

在这里插入图片描述
在这里插入图片描述
torch.stack([a,b],dim = c) 会在dim指定轴之前增加新的维度。但在指定轴的数据维度比==.cat()== 要求更严格同样必须一致。

代码如下:

import torch

a = torch.rand(4,3,28,28)
b = torch.rand(6,3,28,28)
#除0轴数据维度必须保持相同(46 不同 会出错),123轴数据维度都相同否则会无法合并报错
c = torch.stack([a,b],dim = 0)
print(c.shape)

输出结果:

在这里插入图片描述
代码如下:

import torch

a = torch.rand(6,3,28,28)
b = torch.rand(6,3,28,28)
#除0轴数据维度必须保持相同(66),123轴数据维度都相同否则会无法合并报错
c = torch.stack([a,b],dim = 0)
print(c.shape)

输出结果:

在这里插入图片描述

2.张量分割

.split(a,dim = b) 将张量按照a的指定长度在b轴上进行拆分。
.chunk(a,dim = b) 将张量拆分成a个数量在b轴上进行拆分。

代码如下:

import torch

a = torch.rand(6,3,28,28)
print(a.shape)
#split以长度进行拆分,3是指长度,每n个进行一个拆分,要有接受数据要保持对应
b,c,d =a.split(2,dim=0)
print(b.shape,c.shape,d.shape)
b,c = a.split(3,dim=0)
print(b.shape,c.shape)
print("********************************************************************")
#chunk,是拆分成指定的n个
b,c = a.chunk(2,dim = 0)
print(b.shape,c.shape)
b,c,d =a.chunk(3,dim=0)
print(b.shape,c.shape,d.shape)

输出结果:

在这里插入图片描述


二、数学运算


1.张量的四则运算

加法:若相加数据的维度不同,符合广播机制的会广播后再相加。
+号 可以使用加号进行相加。
torch.add() 也可以调用add方式相加。

代码如下:

import torch

a = torch.rand(4,5)
b = torch.rand(5)
print(a)
print(b)
#因为b的维度不够所有且符合广播机制,torch会将b广播成与a相同然后相加
#加法具有两种实现形式,一种是+号,另一种是调用add方式
print(a+b)
print(torch.add(a,b))

输出结果:
在这里插入图片描述

减法
-号 可以使用重载运算符减号进行相减。
torch.sub() 也可以调用sub(减法:subtraction)方式相减。

代码如下:

import torch

a = torch.rand(4,5)
b = torch.rand(5)
print(a)
print(b)

print(a-b)
print(torch.sub(a,b))

输出结果:
在这里插入图片描述

乘法:乘法分为元素相乘(即对应位置的元素想乘)和矩阵乘法

元素相乘
*号 :可以使用重载运算符星号进行对应元素相乘。
torch.mul() :也可以调用mul(乘法:multiply)方法相乘。

矩阵乘法:需满足矩阵的运算规则,如A的列数(4行5列),等于C的行数(5行8列)得到新的维度(4行8列)
.mm(a,c)号 :仅适用于2D张量矩阵(不推荐)。
@ :重载运算符符号号进行矩阵相乘。
torch.matmul() :也可以调用.matmul()方法进行矩阵相乘。(在3D、4D等多维张量矩阵乘法中,只计算最后两个轴。如(1,2,3,4)@(1,2,4,5)=(1,2,3,5))
代码如下:

import torch

a = torch.rand(4,5)
b = torch.rand(5)
c = torch.rand(5,8)
print(a)
print(b)
#各对应元素相乘
print(a*b)
print(torch.mul(a,b))
#矩阵乘法:torch.mm(仅适用于2D矩阵相乘,不推荐);@符号重载的矩阵乘号;.matmul()函数等三种方法
#矩阵乘法要满足矩阵的运算规则:即A的列数(45),等于的行数C(58列)得到新的维度(48列)
d = a@c
e = torch.matmul(a,c)
print(d,d.shape)
print(e,e.shape)

输出结果:
在这里插入图片描述

除法

/号 :可以使用重载运算符 / 号进行对应元素相除。
torch.div() :也可以调用div(除法:divide)方法相除。

代码如下:

import torch

a = torch.rand(4,5)
b = torch.rand(5)

print(a)
print(b)

#各对应元素相除
print(a/b)
print(torch.div(a,b))

输出结果:
在这里插入图片描述

2.张量的幂指运算

.pow(a) :计算x的a次方。也可以使用两个星号来代替 **()
.sqrt() :开平方根。同样可以使用 **(0.5)
.rsqrt() :开平方根后的倒数。

代码如下:

import torch

a = torch.full((3,3),4)
#平方
b = a.pow(2)
#3次方
c = a**(3)
#开平方根
d = a.sqrt()
e = a.pow(0.5)
#平方根的倒数
f = a.rsqrt()
print(a)
print(b)
print(c)
print(d)
print(e)
print(f)

输出结果:

在这里插入图片描述

.exp(a) :计算以e的a次方。
.log(a) :计算以e为底log(a)。
.rsqrt() :计算以10为底log(a)。

代码如下:

import torch

a = torch.full((3,3),2)
#对a每个数均进行e的指定次方
b = torch.exp(a)
# log默认以2为底
c = torch.log(b)
#log 函数以10为底
d = torch.log10(a)
print(a)
print(b)
print(c)
print(d)

输出结果:
在这里插入图片描述

3.张量的近似运算

.floor() :向下取整。
.ceil(a) :向上取整。
.trunc() :截取整数部分。
.frac() :截取小数部分。
.round() :对小数部分进行四舍五入。

代码如下:

import torch

a = torch.tensor(5.64)
#向下取整
print(a.floor())
#向上取整
print(a.ceil())
#截取整数部分
print(a.trunc())
#截取小数部分
print(a.frac())
#对小鼠部分进行四舍五入
print(a.round())

输出结果:
在这里插入图片描述

4.裁剪

.clamp(a,b) :将数据裁剪到 a 到 b 之间。(常用于对梯度裁剪,防止梯度爆炸的情况出现)

代码如下:

import torch

a = torch.rand(3,3)*20
print(a)
print(a.min())
print(a.median())
print(a.max())
#将数据裁剪到 5-15之间,小于5的以5代替,大于15的以15代替
b = a.clamp(5,15)
print(b)

输出结果:
在这里插入图片描述

三、合并与分割

1.范数

1范数 :所有数据绝对值之和。

在这里插入图片描述

2范数 :所有数据的平方和开根号。
在这里插入图片描述
p范数 :所有数据的p次方和开p根号。
在这里插入图片描述

1范数和2范数,未指定轴分析 :即对所有的数的绝对值求和,以及开根号。
代码如下:

a = torch.full([8],1.)
b = torch.full((2,4),1.)
c = torch.full((2,3,4),1.)
print(a)
print(b)
print(c)
# 1范数:所有数据的绝对值之和 , 2范数平方和开根号
print(a.norm(1),b.norm(1),c.norm(1))
print(a.norm(2),b.norm(2),c.norm(2))

输出结果:
ss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2R4ZjEwMTc1MjQxNTc=,size_16,color_FFFFFF,t_70)

1范数和2范数,以c为例在指定轴分析 :C为3D张量分别再0,1,2三个指定轴求1范数,分析数据的计算。
0轴:c的形状为(2,3,4),在0轴分析,如果未设置轴保留=truse,则1 范数形状应为(3,4)
在这里插入图片描述
1轴:c的形状为(2,3,4),在1轴分析,如果未设置轴保留=truse,则1 范数形状应为(2,4),可以理解对在垂直方向相加。
在这里插入图片描述
2轴:c的形状为(2,3,4),在2轴分析,如果未设置轴保留=truse,则1 范数形状应为(2,3),可以理解对在水平方向相加。
在这里插入图片描述

代码如下:

import torch

c = torch.full((2,3,4),1.)

print(c)
#在指定的轴求范数
#0print(c.norm(1,dim=0))
print(c.norm(2,dim=0))
#1print(c.norm(1,dim=1))
print(c.norm(2,dim=1))
#2print(c.norm(1,dim=2))
print(c.norm(2,dim=2))

2.序号索引

.min() :获取张量数据中的最小值。
.max() :获取张量数据中的最大值。

1.未指定索引轴时

pytorch采用的是将整个张量打平和1D张量,根据最大值和最小值获取位置索引。
.argmin() :获取打平后张量数据中的最小值索引。
.argmax() :获取打平后张量数据中的最大值索引。

2.指定索引轴时

代码如下:

import torch

a = torch.arange(24).view(2,3,4).float()
print(a)
#打印张量的最大值和最小值
print(a.min(),a.max())
#打印张量最大值,最小值对应的索引,无参数指定时默认flatten
print(a.argmin(),a.argmax())
#若不想打平,则需要指定轴
#0轴可以理解未垂直方向取索引
print(a.argmin(dim=0))
print(a.argmax(dim=0))
#1轴可以理解为水平方向取索引
print(a.argmin(dim=1))
print(a.argmax(dim=1))
#1轴可以理解为水平方向取索引
print(a.argmin(dim=2))
print(a.argmax(dim=2))

0轴:a的形状为(2,3,4),在0轴索引分析,是对应位置索引,索引值形状未(3,4)。
在这里插入图片描述
1轴:a的形状为(2,3,4),在1轴索引分析,可以理解为竖向(垂直)取索引,索引值形状未(2,4)。
在这里插入图片描述
2轴:a的形状为(2,3,4),在2轴索引分析,可以理解为横向(水平)取索引,索引值形状未(2,3)。
在这里插入图片描述

3.保持维度

keepdim :对指定的轴取索引时,如果保持轴数不变需要使用 keepdim 保持。

代码如下:

import torch

a = torch.randn(4,10)
print(a)
#打印在1轴最大值及对应索引
print(a.max(dim=1))
#打印索引
print(a.argmax(dim =1))
print("***********************************")
#打印在1轴最大值及对应索引,保留轴
print(a.max(dim=1,keepdim = True))
#打印索引
print(a.argmax(dim =1,keepdim = True))

输出结果:
在这里插入图片描述

3.保留前K个值(TOP-K)

在分类问题中,由于各种原因,可能会出现,分类的某一问题概率值并不高,为了更准确的分类,我们会需要保留的大的前K个概率,进一步判断药分类的类别。
.topk(a) :保留前a个概率值。
.kthvalue(a) :需要注意的是保留第a个小的,并且只能设置为小。
代码如下:

import torch

a = torch.randn(2,8)
print(a)
#largest 默认为True取最大的前k个,False取最小的前k个
#取最大的前3个及对应的索引号
print(a.topk(3,dim=1))
#取最小的前3个及对应的索引号
print(a.topk(3,dim=1,largest = False))
#取第k个小的值,只能取小
print(a.kthvalue(1,dim=1))

输出结果:
在这里插入图片描述

4.逻辑关系( > < = !=)

大于:可以直接用重载运算符 > 或者==.gt()== 大于(great)比较。
小于:可以直接用重载运算符 < 或者==.lt()== 小于比较。
等于:可以直接用重载运算符 == 或者 .eq() 等于(equal)。
不等于:可以直接用重载运算符 != 或者 .not_equal()
代码如下:

import torch

a = torch.arange(9).view(3,3)
print(a)
#大于
print(a>5)
print(torch.gt(a,5))
#小于
print(a<5)
print(torch.lt(a,5))
#等于
print(a == 5)
print(torch.eq(a,5))
#不等于
print(a != 5)
print(torch.not_equal(a,5))

输出结果:
在这里插入图片描述
在这里插入图片描述

四、where与gather

1.where 条件赋值

.where(condition , x, y ) :如果满足条件,会将x中对应元素赋值给输出,不满足则将y对应数值赋给输出。

代码如下:

import torch

condation = torch.randn(3,3)
print(condation)
x = torch.full((3,3),0.)
print(x)
y = torch.full((3,3),1.)
print(y)
#where 用法
print(torch.where(condation>0.5,x,y))

输出结果:

在这里插入图片描述

2.gather

.gather(input , dim, index,out =None) :将数据索引映射到所需要的位置。

代码如下:

import torch

#数据
data = torch.randn(3,6)
print(input)
#索引 在输入的数据中在1轴上去前2个最大值及索引
indexz_data = data.topk(2,dim=1)
print(indexz_data)
idx = indexz_data.indices
#将数据索引映射到另一个位置 [50 - 56]
label = torch.arange(6) + 50
print(label)
#使用gathe进行对应查找
print(torch.gather(label.expand(3,6),dim=1,index = idx))

输出结果:
在这里插入图片描述


总结

本节,承上对Pytorch中常用的一些方法进行补充和解释,敬请小伙伴们批评指正,学习讨论,觉得有价值,劳驾动动食指,点个赞哈。

;原文链接:https://blog.csdn.net/dxf1017524157/article/details/115408378
本站部分内容转载于网络,版权归原作者所有,转载之目的在于传播更多优秀技术内容,如有侵权请联系QQ/微信:153890879删除,谢谢!

推荐图文


随机推荐