【TF/Guide笔记】 06. Training loops
这一节主要侧重于训练的基本原理,怎么用tf代码实现一个简单的线性回归,与tf的设计本身没太大关系,所以本来想跳过了,结果发现了个大bug。
关于fit的使用,compile这一步提供了两个关键的东西,一个optimizer一个loss,这里有一个默认前提是,继承自keras.Model类的call函数,实现的功能必须是由x到y的过程,也就是inputs到prediction,然后通过loss函数(准确来说是loss_function)利用prediction和label算出loss,最后通过optimizer利用gradient计更新weight。
这里跟我理解中偏差最大的是optimizer,通过源码也确认了一下,用来更新weight的值的确是单纯的gradient,也就是通常说的 ,这个g值与loss这轮算出来的值是没有关系的。
然而如果是这样的话,就与我在autodiff那会儿分析出来的结论有差异了,如果我们算出来的是dx而tf算出的是dl/dx,tf并没有最后乘一下loss的话,那结果应该会不一样才对,但是以前我们是对拍过中间结果的。于是我倒回去仔细的推了一下发现,原来以前我们算的也是对的。
对于每一个中间结果u,我都算出来了dl/du,然后把这个结果回传给上游,上游的输入x直接利用dl/du以及x算出了dl/dx,相当于把链式法则的部分融入进了每一个backpass的实现,而这里求出来的导数是正确的。
区别在于,tf可以在计算后指定什么对什么求导,而我们必须在连图的时候就定下目标,从目标的地方连通计算图回路。这样一看,我们的设计竟然比tf更合理,因为训练中不存在对多个target的求导的情况,这一点tf自己也知道,才会在tape.gradient之后自动析构中间结果,而且tf会在确定了链式图的两端之后剪枝掉无用的中间结果,而我们预先连图的方式会在一开始就把这些无用计算剪枝掉。
所谓的连通回路,其实就是把loss.diff设为1的过程,因为对于每一个变量x来说,x.diff实际上存储的是 d target / dx,当target=l的时候,dl/dl自然就是1了,而1作为常量,下游节点可以判断出回路是通的,在计算拓扑的时候会被判定为需要计算。
不过一般来说,计算图里并不会有什么tf算了但最后用不到的梯度,除非用户瞎写。硬要说区别的话,我们把更多的计算融入进了一次计算图里,有空间做更多的优化,但这点当然是可以通过舍弃tape,直接用op写出训练过程来实现了。