欢迎光临散文网 会员登陆 & 注册

【TF/Guide笔记】 06. Training loops

2022-03-01 14:44 作者:纪一希  | 我要投稿

    这一节主要侧重于训练的基本原理,怎么用tf代码实现一个简单的线性回归,与tf的设计本身没太大关系,所以本来想跳过了,结果发现了个大bug。

    关于fit的使用,compile这一步提供了两个关键的东西,一个optimizer一个loss,这里有一个默认前提是,继承自keras.Model类的call函数,实现的功能必须是由x到y的过程,也就是inputs到prediction,然后通过loss函数(准确来说是loss_function)利用prediction和label算出loss,最后通过optimizer利用gradient计更新weight。

    这里跟我理解中偏差最大的是optimizer,通过源码也确认了一下,用来更新weight的值的确是单纯的gradient,也就是通常说的 w%20%3D%20w%20-%20%5Calpha%20*g,这个g值与loss这轮算出来的值是没有关系的。

    然而如果是这样的话,就与我在autodiff那会儿分析出来的结论有差异了,如果我们算出来的是dx而tf算出的是dl/dx,tf并没有最后乘一下loss的话,那结果应该会不一样才对,但是以前我们是对拍过中间结果的。于是我倒回去仔细的推了一下发现,原来以前我们算的也是对的。

    对于每一个中间结果u,我都算出来了dl/du,然后把这个结果回传给上游,上游的输入x直接利用dl/du以及x算出了dl/dx,相当于把链式法则的部分融入进了每一个backpass的实现,而这里求出来的导数是正确的。

    区别在于,tf可以在计算后指定什么对什么求导,而我们必须在连图的时候就定下目标,从目标的地方连通计算图回路。这样一看,我们的设计竟然比tf更合理,因为训练中不存在对多个target的求导的情况,这一点tf自己也知道,才会在tape.gradient之后自动析构中间结果,而且tf会在确定了链式图的两端之后剪枝掉无用的中间结果,而我们预先连图的方式会在一开始就把这些无用计算剪枝掉。

    所谓的连通回路,其实就是把loss.diff设为1的过程,因为对于每一个变量x来说,x.diff实际上存储的是 d target / dx,当target=l的时候,dl/dl自然就是1了,而1作为常量,下游节点可以判断出回路是通的,在计算拓扑的时候会被判定为需要计算。

    不过一般来说,计算图里并不会有什么tf算了但最后用不到的梯度,除非用户瞎写。硬要说区别的话,我们把更多的计算融入进了一次计算图里,有空间做更多的优化,但这点当然是可以通过舍弃tape,直接用op写出训练过程来实现了。

【TF/Guide笔记】 06. Training loops的评论 (共 条)

分享到微博请遵守国家法律