资讯

精准传达 • 有效沟通

从品牌网站建设到网络营销策划,从策略到执行的一站式服务

如何解析Pytorch基础中网络参数初始化问题

如何解析Pytorch基础中网络参数初始化问题,很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。

成都创新互联专注于二连浩特企业网站建设,成都响应式网站建设公司,商城建设。二连浩特网站建设公司,为二连浩特等地区提供建站服务。全流程定制网站建设,专业设计,全程项目跟踪,成都创新互联专业和态度为您提供的服务

参数访问和遍历:

对于模型参数,我们可以进行访问;

由于Sequential由Module继承而来,所以可以使用Module钟的parameter()或者named_parameters方法来访问所有的参数;

例如,对于使用Sequential搭建的网络,可以使用下列for循环直接进行遍历:

for name, param in net.named_parameters():
    print(name, param.size())

当然,也可以使用索引来按层访问,因为本身网络也是按层搭建的:

for name, param in net[0].named_parameters():
    print(name, param.size(), type(param))

当我们获取某一层的参数信息后,可以使用data()和grad()函数来进行值和梯度的访问:

weight_0 = list(net[0].parameters())[0]
print(weight_0.data)
print(weight_0.grad) # 反向传播前梯度为None
Y.backward()
print(weight_0.grad)

参数初始化问题:

当我们参用for循环获取每层参数,可以采用如下形式对w和偏置b进行初值设定:

for name, param in net.named_parameters():
    if 'weight' in name:
        init.normal_(param, mean=0, std=0.01)
        print(name, param.data)

for name, param in net.named_parameters():
    if 'bias' in name:
        init.constant_(param, val=0)
        print(name, param.data)

当然,我们也可以进行初始化函数的自定义设置:

def init_weight_(tensor):
    with torch.no_grad():
        tensor.uniform_(-10, 10)
        tensor *= (tensor.abs() >= 5).float()

for name, param in net.named_parameters():
    if 'weight' in name:
        init_weight_(param)
        print(name, param.data)

这里注意一下torch.no_grad()的问题;

该形式表示该参数并不随着backward进行更改,常常用来进行局部网络参数固定的情况;

如该连接所示:关于no_grad()

共享参数:

可以自定义Module类,在forward中多次调用同一个层实现;

如上章节的代码所示:

class FancyMLP(nn.Module):
    def __init__(self, **kwargs):
        super(FancyMLP, self).__init__(**kwargs)
        self.rand_weight = torch.rand((20, 20), requires_grad=False) # 不可训练参数(常数参数)
        self.linear = nn.Linear(20, 20)
    def forward(self, x):
        x = self.linear(x)
        # 使用创建的常数参数,以及nn.functional中的relu函数和mm函数
        x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1)
        # 复用全连接层。等价于两个全连接层共享参数
        x = self.linear(x)
        # 控制流,这里我们需要调用item函数来返回标量进行比较
        while x.norm().item() > 1:
            x /= 2
        if x.norm().item() < 0.8:
            x *= 10
        return x.sum()

所以可以看到,相当于同时在同一个网络中调用两次相同的Linear实例,所以变相实现了参数共享;

suo'yi注意一下,如果传入Sequential模块的多层都是同一个Module实例的话,则他们共享参数;

看完上述内容是否对您有帮助呢?如果还想对相关知识有进一步的了解或阅读更多相关文章,请关注创新互联行业资讯频道,感谢您对创新互联的支持。


网页题目:如何解析Pytorch基础中网络参数初始化问题
网页路径:http://www.cdkjz.cn/article/gsphcp.html
多年建站经验

多一份参考,总有益处

联系快上网,免费获得专属《策划方案》及报价

咨询相关问题或预约面谈,可以通过以下方式与我们联系

大客户专线   成都:13518219792   座机:028-86922220