赞
踩
title: TensorFlow-静态图和PyTorch-动态图区别
categories:
tags:
最近在重新学习一遍pytorch,之前对于自动求导中的计算图的概念不是很清楚,这里从头看了一遍,有了解一下,简单的写一下自己的笔记。
PyTorch自动求导看起来非常像TensorFlow,这两个框架中,我们都定义了计算图,使用自动微分来计算梯度,但是两者之间最大的不同是TensorFlow的计算图是静态的,而PyTorch使用的是动态的计算图。
在TensorFlow中,我们定义计算图一次,然后后续就会重复执行这个相同的图,后面的话可能只是会提供不同的输入数据,而在PyTorch中,每一个前向通道(forward)定义一个新的计算图。
静态图的好处在于你可以预先对图进行优化。例如:一个框架可能要融合一些图的运算来提升效率,或者产生一个策略来将图分布到多个GPU或者机器上,如果重复使用相同的图,那么再重复运行一个图时,前期潜在的代价高昂的预先优化的消耗就会被分摊开。
静态图和动态图的一个区别是控制流。对于一些模型,我们希望对每个数据点执行不同的计算。例如:一个递归神经网络可能对每个数据点执行不同的时间步数,这个展开(unrolling)可以作为一个循环来实现。
对于一个静态图,循环结构要作为图的一部分,因此TensorFlow提供了运算符来把循环嵌入到图当中。对于动态图来说,情况更加简单,既然我们为每个例子即时创建计算图,我们可以使用普通的命令式控制流来为每个输入执行不同的计算。
tensorflow的forward只会根据第一次模型前向传播来构建一个静态的计算图, 后面的梯度自动求导都是根据这个计算图来计算的, 但是pytorch则不是, 它会为每次forward计算都构建一个动态图的计算图, 后续的每一次迭代都是使用一个新的计算图进行计算的.
作为动态图(网络结构发生变化并不影响计算图计算梯度)和权重共享的一个例子,我们实现了一个非常奇怪的模型:一个全连接的ReLU网络,在每一次前向传播时,它的隐藏层的层数为随机1到4之间的数,这样可以多次重用相同的权重来计算。
因为这个模型可以使用普通的Python流控制来实现循环,并且我们可以通过定义转发时多次重用同一个模块来实现最内层的权重共享。
下面是例子的代码:
import torch import torch.nn as nn import random # fixme: 定义网络 class DynamicNet(nn.Module): def __init__(self, D_in, H, D_out): """ 构造函数,在这里需要将网络的各个模块进行实例化, 并把他们作为成员变量 :params D_in: 输入维度 :params H: 隐藏层维度 :params D_out: 输出维度 """ super(DynamicNet, self).__init__() self.input_layer = nn.Linear(D_in, H) self.hidden_layer = nn.Linear(H, H) self.output_layer = nn.Linear(H, D_out) self.relu = nn.ReLU() def forward(self, x): x1 = self.input_layer(x) t_relu = self.relu(x1) # 定义0-3个隐藏层,利用pytorch动态图的特征,这种做法是可行的 ## 重复调用self.hidden_layer 0-3次,由于pytorch是采用动态图的,因此每一次forward都会创建一个新的动态图,不影响梯度计算 for _ in range(random.randint(0, 3)): t_relu = self.relu(self.hidden_layer(t_relu)) pred = self.output_layer(t_relu) return pred # fixme: 参数配置 dtype = torch.float device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') N, D_in, H, D_out = 64, 1000, 100, 10 # N为批量大小,D_in是输入维度,H是隐藏层维度,D_out是输出层维度 learning_rate = 1e-4 epochs = 100 # fixeme: 创建输入和输出随机张量 input = torch.randn(N, D_in) label = torch.randn(N, D_out) # fixme: 实例化模型 model = DynamicNet(D_in, H, D_out) # fixme: 损失函数的定义 loss_fn = torch.nn.MSELoss(reduction='sum') # fixme: 使用torch.optim定义参数优化器 optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate) # 这里用的是Adam # fixme: 训练 for epoch in range(epochs): # 前向过程 pred = model(input) # 计算loss loss = loss_fn(pred, label) print('当前代数:{},当前loss为:{}'.format(epoch,loss.item())) # 模型参数梯度置零 optimizer.zero_grad() # loss反向传播 loss.backward() # 更新权重 optimizer.step()
当前代数:0,当前loss为:702.3717651367188 当前代数:1,当前loss为:700.1735229492188 当前代数:2,当前loss为:697.996826171875 当前代数:3,当前loss为:701.37646484375 当前代数:4,当前loss为:745.0324096679688 当前代数:5,当前loss为:686.2260131835938 当前代数:6,当前loss为:729.3468017578125 当前代数:7,当前loss为:719.3234252929688 当前代数:8,当前loss为:700.3496704101562 当前代数:9,当前loss为:697.38720703125 当前代数:10,当前loss为:676.840576171875 当前代数:11,当前loss为:674.8969116210938 当前代数:12,当前loss为:667.5950927734375 当前代数:13,当前loss为:657.723388671875 当前代数:14,当前loss为:668.5684814453125 当前代数:15,当前loss为:691.20361328125 当前代数:16,当前loss为:629.1569213867188 当前代数:17,当前loss为:662.6659545898438 当前代数:18,当前loss为:611.5732421875 当前代数:19,当前loss为:602.4924926757812 当前代数:20,当前loss为:698.65625 当前代数:21,当前loss为:583.9902954101562 当前代数:22,当前loss为:655.1369018554688 当前代数:23,当前loss为:566.2965087890625 当前代数:24,当前loss为:557.3516845703125 当前代数:25,当前loss为:547.8927001953125 当前代数:26,当前loss为:650.374755859375 当前代数:27,当前loss为:648.9633178710938 当前代数:28,当前loss为:697.7493286132812 当前代数:29,当前loss为:515.4136962890625 当前代数:30,当前loss为:507.982421875 当前代数:31,当前loss为:500.00030517578125 当前代数:32,当前loss为:491.6020812988281 当前代数:33,当前loss为:482.8730163574219 当前代数:34,当前loss为:697.01220703125 当前代数:35,当前loss为:687.0211791992188 当前代数:36,当前loss为:638.8866577148438 当前代数:37,当前loss为:696.5703735351562 当前代数:38,当前loss为:636.4703369140625 当前代数:39,当前loss为:685.7159423828125 当前代数:40,当前loss为:685.2298583984375 当前代数:41,当前loss为:684.5957641601562 当前代数:42,当前loss为:695.6959228515625 当前代数:43,当前loss为:427.000732421875 当前代数:44,当前loss为:682.4891967773438 当前代数:45,当前loss为:419.1412048339844 当前代数:46,当前loss为:681.0149536132812 当前代数:47,当前loss为:694.716552734375 当前代数:48,当前loss为:406.95831298828125 当前代数:49,当前loss为:678.7310180664062 当前代数:50,当前loss为:623.3473510742188 当前代数:51,当前loss为:395.0804443359375 当前代数:52,当前loss为:693.6585083007812 当前代数:53,当前loss为:675.7344360351562 当前代数:54,当前loss为:618.6995239257812 当前代数:55,当前loss为:616.9414672851562 当前代数:56,当前loss为:614.5744018554688 当前代数:57,当前loss为:692.4542236328125 当前代数:58,当前loss为:609.0692138671875 当前代数:59,当前loss为:691.904052734375 当前代数:60,当前loss为:603.1943359375 当前代数:61,当前loss为:669.9990234375 当前代数:62,当前loss为:691.0020141601562 当前代数:63,当前loss为:594.264892578125 当前代数:64,当前loss为:591.1102294921875 当前代数:65,当前loss为:666.8950805664062 当前代数:66,当前loss为:665.9771728515625 当前代数:67,当前loss为:689.326171875 当前代数:68,当前loss为:688.93701171875 当前代数:69,当前loss为:688.5073852539062 当前代数:70,当前loss为:574.0647583007812 当前代数:71,当前loss为:661.115234375 当前代数:72,当前loss为:660.0462036132812 当前代数:73,当前loss为:566.465576171875 当前代数:74,当前loss为:360.7765197753906 当前代数:75,当前loss为:685.8043212890625 当前代数:76,当前loss为:357.83026123046875 当前代数:77,当前loss为:556.7740478515625 当前代数:78,当前loss为:684.478515625 当前代数:79,当前loss为:683.9954223632812 当前代数:80,当前loss为:683.468994140625 当前代数:81,当前loss为:650.9133911132812 当前代数:82,当前loss为:545.8726806640625 当前代数:83,当前loss为:543.443359375 当前代数:84,当前loss为:540.5396118164062 当前代数:85,当前loss为:342.5977478027344 当前代数:86,当前loss为:645.627685546875 当前代数:87,当前loss为:644.410888671875 当前代数:88,当前loss为:642.9930419921875 当前代数:89,当前loss为:641.402099609375 当前代数:90,当前loss为:524.3102416992188 当前代数:91,当前loss为:521.52880859375 当前代数:92,当前loss为:518.2946166992188 当前代数:93,当前loss为:676.001220703125 当前代数:94,当前loss为:633.2501220703125 当前代数:95,当前loss为:674.614501953125 当前代数:96,当前loss为:673.8367919921875 当前代数:97,当前loss为:628.274169921875 当前代数:98,当前loss为:626.4828491210938 当前代数:99,当前loss为:325.72052001953125
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。