After my last question about confusion over training a recurrent NN using Flux, I dove deeper into the Flux training process, and now I'm even more confused. I think my trouble is centered around using a sum in the loss function, so that the loss takes into account many points in a sequence. See here, where the loss is defined as:
loss(x, y) = sum((Flux.stack(m.(x),1) .- y) .^ 2)
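For context, this is roughly how I evaluate that loss over a whole sequence (a minimal sketch using the toy data that appears further down; it assumes the Flux version I'm on, where stack takes a positional dims argument):

using Flux

m = Flux.RNN(1, 1)            # 1 -> 1 recurrent layer, default tanh activation
x = [[0.3], [2.5]]            # a sequence of two 1-element inputs
y = [0.5, 1.0]                # one target per time step

loss(x, y) = sum((Flux.stack(m.(x), 1) .- y) .^ 2)

Flux.reset!(m)                # zero the hidden state before scoring the sequence
loss(x, y)                    # total squared error over both time steps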
If x is a sequence with multiple points, and y is the corresponding output for each point, this loss function evaluates the loss for the whole sequence. What I'm trying to understand is how Flux takes the gradient of a function like this. Imagine simplifying it to:
loss(x, y) = sum(Flux.stack(m.(x), 1) .- y)
We can also create a very simple recurrent neural "network" as a single 1 -> 1 node with no activation function:
m = Flux.RNN(1, 1, x -> x)
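(As an aside, the trainable weights of that layer sit inside the wrapped cell; the field names below are from the Flux version I'm using, so treat this as a sketch:)

m = Flux.RNN(1, 1, x -> x)
m.cell.Wi   # 1×1 input weight matrix  -- what I call Wx below
m.cell.Wh   # 1×1 recurrent weight matrix
m.cell.b    # length-1 bias vector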
This is (sort of) equivalent to:
h = [0.0]                          # hidden state
function m(x)
    y = Wx .* x + Wh .* h .+ b     # Wx, Wh, b are global parameter vectors
    global h = y                   # update the hidden state
    return y
end

What's the gradient of the loss with respect to Wx? Take a sequence with two points, x = [x1, x2], and targets y* = [y1*, y2*]. Put x1 through the RNN (with initial hidden state h1) and you get:
y1 = h2 = Wx*x1 + Wh*h1 + b
Then put x2 through and you get:
y2 = h3 = Wx*x2 + Wh*h2 + b = Wx*x2 + Wh*(Wx*x1 + Wh*h1 + b) + b.
Now calculate the loss:
L = y1 - y1* + y2 - y2* = Wx*x1 + Wh*h1 + b - y1* + Wx*x2 + Wh*(Wx*x1 + Wh*h1 + b) + b - y2*
It seems obvious that dL/dWx should be x1 + x2 + Wh*x1. So let's say x and y are:
x = [[0.3], [2.5]]
y = [0.5, 1.0]

and the parameters are initialized to:
Wx = [0.5]
Wh = [0.001]
b = [0.85]

If you calculate dL/dWx = x1 + x2 + Wh*x1, you get 2.8003. You could also try a finite difference:
h = [0.0]
q = loss(x, y)
Wx .+= 0.01
h = [0.0]
r = loss(x, y)
abs(q - r) / 0.01   # = 2.8003

and get 2.8003. But if you use Flux's gradient function:
Wx = [0.5]
h = [0.0]
gs = gradient(() -> loss(x, y), params(Wx, Wh, b))
gs[Wx]   # = 2.8025

you get 2.8025, which seems to be x1 + x2 + Wh*x2. I don't understand why the results are different, especially considering that everything is in agreement when evaluating the two loss functions themselves. Is there something I'm overlooking? Is there something weird going on in gradient?
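For completeness, here is everything above collected into one self-contained script (the hand-rolled m, the simplified loss, my hand-derived gradient, the finite difference, and the Flux gradient). Treat it as a sketch: it assumes the implicit params/gradient API of the Flux version I'm using, and the commented numbers are the values quoted above.

using Flux

Wx = [0.5]                 # parameters, kept global so m can close over them
Wh = [0.001]
b  = [0.85]
h  = [0.0]                 # hidden state

function m(x)              # hand-rolled 1 -> 1 RNN with identity activation
    y = Wx .* x + Wh .* h .+ b
    global h = y
    return y
end

loss(x, y) = sum(Flux.stack(m.(x), 1) .- y)   # the simplified (un-squared) loss

x = [[0.3], [2.5]]
y = [0.5, 1.0]

# hand-derived gradient: dL/dWx = x1 + x2 + Wh*x1
0.3 + 2.5 + 0.001 * 0.3                       # = 2.8003

# finite difference
h = [0.0]
q = loss(x, y)
Wx .+= 0.01
h = [0.0]
r = loss(x, y)
abs(q - r) / 0.01                             # ≈ 2.8003

# Flux / Zygote gradient
Wx = [0.5]
h = [0.0]
gs = gradient(() -> loss(x, y), params(Wx, Wh, b))
gs[Wx]                                        # = 2.8025, i.e. x1 + x2 + Wh*x2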