2021年5月8日星期六

Why does my neural network not converge using JAX?

I am learning JAX, but I have run into a weird problem. If I use the following code,

# Fit cos(x) on [-5, 5] with a small fully connected network built from
# stax layers and trained with Adam.
#
# NOTE(review): newer JAX releases appear to ship stax/optimizers under
# jax.example_libraries instead of jax.experimental -- confirm against the
# installed version.
import numpy as np
import jax.numpy as jnp
from jax import grad, value_and_grad
from jax import vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling functions for speedup
from jax import random # stax initialization uses jax.random
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers
import matplotlib.pyplot as plt # visualization

# Three hidden Dense(40)+Relu layers followed by a single linear output unit.
net_init, net_apply = stax.serial(
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(1)
)
rng = random.PRNGKey(0)
in_shape = (-1, 1,)  # one input feature; -1 presumably marks the batch axis -- TODO confirm
out_shape, params = net_init(rng, in_shape)

def loss(params, X, Y):
    """Return the mean squared difference between Y and net_apply(params, X)."""
    predictions = net_apply(params, X)
    # NOTE(review): presumably Dense(1) keeps a trailing axis of size 1, so
    # predictions has the same shape as the (100, 1) targets and the
    # subtraction is elementwise -- confirm.
    return jnp.mean((Y - predictions)**2)

@jit
def step(i, opt_state, x1, y1):
    """Run one optimizer update on (x1, y1).

    Returns the loss value computed before the update together with the new
    optimizer state.
    """
    p = get_params(opt_state)
    val, g = value_and_grad(loss)(p, x1, y1)
    return val, opt_update(i, g, opt_state)

# NOTE(review): `optimizers` is never imported in this snippet; it needs
# something like `from jax.experimental import optimizers`.
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)

# NOTE(review): xrange_inputs and targets are only assigned further down, so
# running this snippet top to bottom in a fresh interpreter raises NameError
# here; the data definitions have to be executed first.
val_his = []  # loss value recorded at every step
for i in range(100):
    val, opt_state = step(i, opt_state, xrange_inputs, targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)

# Training / evaluation data: 100 evenly spaced points and their cosine.
xrange_inputs = jnp.linspace(-5,5,100).reshape((100, 1)) # (k, 1)
targets = jnp.cos(xrange_inputs)
# vmap maps over the leading axis, so the network is applied to one input at a time.
predictions = vmap(partial(net_apply, params))(xrange_inputs)
losses = vmap(partial(loss, params))(xrange_inputs, targets) # per-input loss

plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()

the neural network can approximate the function cos(x) well.

But if I rewrite the neural network part myself, as follows,

# Fit cos(x) on [-5, 5] with a hand-written fully connected ReLU network
# trained with Adam.
#
# NOTE(review): newer JAX releases appear to ship stax/optimizers under
# jax.example_libraries instead of jax.experimental -- confirm against the
# installed version.
import numpy as np
import jax.numpy as jnp
from jax import grad, value_and_grad
from jax import vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling functions for speedup
from jax import random # stax initialization uses jax.random
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers
import matplotlib.pyplot as plt # visualization
import numpy as np
from jax.experimental import optimizers
from jax.tree_util import tree_multimap

def initialize_NN(layers, key):
    """Build the parameter list for a fully connected network.

    `layers` lists the layer widths (input first, output last). The result is
    a list with one (W, b) tuple per pair of consecutive widths, where W comes
    from xavier_init and b is a float32 zero vector.
    """
    params = []
    num_layers = len(layers)
    # One key per entry of `layers`; only the first num_layers-1 are consumed below.
    keys = random.split(key, len(layers))
    a = jnp.sqrt(0.1)  # unused: its only consumer is the commented-out line below
    #params.append(a)
    for l in range(0, num_layers-1):
        W = xavier_init((layers[l], layers[l+1]), keys[l])
        b = jnp.zeros((layers[l+1],), dtype=np.float32)
        params.append((W,b))
    return params

def xavier_init(size, key):
    """Sample a float32 weight matrix of shape (out_dim, in_dim).

    `size` is (in_dim, out_dim). Values are drawn from a normal distribution
    truncated to [-2, 2] and scaled by sqrt(2 / (in_dim + out_dim)).
    """
    in_dim = size[0]
    out_dim = size[1]

    xavier_stddev = jnp.sqrt(2/(in_dim + out_dim))
    return random.truncated_normal(key, -2, 2, shape=(out_dim, in_dim), dtype=np.float32)*xavier_stddev


def net_apply(params, X):
    """Forward pass: ReLU on every layer except the last, which is linear.

    Weights are stored as (out_dim, in_dim), hence the `W.T` in each product.
    The result is squeezed, so all size-1 axes are removed from the output.
    """
    num_layers = len(params)
    #a = params[0]
    for l in range(0, num_layers-1):
        W, b = params[l]
        X = jnp.maximum(0, jnp.add(jnp.dot(X, W.T), b))
    W, b = params[-1]
    Y = jnp.dot(X, W.T)+ b
    # NOTE(review): for the (100, 1) training batch the product above is
    # (100, 1); squeezing turns it into (100,), which no longer matches the
    # (100, 1) targets passed to `loss`.
    Y = jnp.squeeze(Y)
    return Y


def loss(params, X, Y):
    """Return the mean of (Y - net_apply(params, X))**2."""
    predictions = net_apply(params, X)
    # NOTE(review): during training Y is (100, 1) and predictions is (100,),
    # so the subtraction broadcasts to (100, 100): every target is compared
    # with every prediction. The mean of that matrix is minimised when every
    # prediction equals the mean of the targets, i.e. a constant output.
    return jnp.mean((Y - predictions)**2)

key = random.PRNGKey(1)
layers = [1,40,40,40,1]  # 1 input, three hidden layers of 40 units, 1 output
params = initialize_NN(layers, key)

@jit
def step(i, opt_state, x1, y1):
    """Run one optimizer update on (x1, y1).

    Returns the loss value computed before the update together with the new
    optimizer state.
    """
    p = get_params(opt_state)
    val, g = value_and_grad(loss)(p, x1, y1)
    return val, opt_update(i, g, opt_state)

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)

# NOTE(review): xrange_inputs and targets are only assigned further down, so
# running this snippet top to bottom in a fresh interpreter raises NameError
# here; the data definitions have to be executed first.
val_his = []  # loss value recorded at every step
for i in range(10000):
    val, opt_state = step(i, opt_state, xrange_inputs, targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)


# Training / evaluation data: 100 evenly spaced points and their cosine.
xrange_inputs = jnp.linspace(-5,5,100).reshape((100, 1)) # (k, 1)
targets = jnp.cos(xrange_inputs)
# vmap maps over the leading axis, so the network is applied to one input at a time.
predictions = vmap(partial(net_apply, params))(xrange_inputs)
losses = vmap(partial(loss, params))(xrange_inputs, targets) # per-input loss

plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()

my neural network always converges to a constant; it seems to be trapped in a local minimum. I am really confused by this.

The only differences should be the initialization, the neural network part, and the way the parameter `params` is set up. I have tried different initializations, which makes no difference. I wonder whether the way `params` is set up for optimization is wrong, and that is why I cannot get convergence.

https://stackoverflow.com/questions/67444953/why-my-neural-does-not-converge-using-jax May 08, 2021 at 03:05PM

没有评论:

发表评论