I am learning JAX, but I have run into a strange problem. If I use the following code,
import numpy as np
import jax.numpy as jnp
from jax import grad, value_and_grad
from jax import vmap                      # for auto-vectorizing functions
from functools import partial             # for use with vmap
from jax import jit                       # for compiling functions for speedup
from jax import random                    # stax initialization uses jax.random
from jax.experimental import stax         # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax  # neural network layers
from jax.experimental import optimizers   # Adam optimizer
import matplotlib.pyplot as plt           # visualization

net_init, net_apply = stax.serial(
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(1)
)
rng = random.PRNGKey(0)
in_shape = (-1, 1,)
out_shape, params = net_init(rng, in_shape)

def loss(params, X, Y):
    predictions = net_apply(params, X)
    return jnp.mean((Y - predictions)**2)

@jit
def step(i, opt_state, x1, y1):
    p = get_params(opt_state)
    val, g = value_and_grad(loss)(p, x1, y1)
    return val, opt_update(i, g, opt_state)

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)

xrange_inputs = jnp.linspace(-5, 5, 100).reshape((100, 1))  # (k, 1)
targets = jnp.cos(xrange_inputs)

val_his = []
for i in range(100):
    val, opt_state = step(i, opt_state, xrange_inputs, targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)

predictions = vmap(partial(net_apply, params))(xrange_inputs)
losses = vmap(partial(loss, params))(xrange_inputs, targets)  # per-input loss

plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()
the neural network can approximate the function cos(x) well.
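For reference, a quick check like the following (run after the block above, using the same variable names) gives a concrete measure of the fit:

final_mse = loss(params, xrange_inputs, targets)      # mean squared error over the whole grid
max_err = jnp.max(jnp.abs(predictions - targets))     # worst-case pointwise error
print("final MSE:", final_mse, "max abs error:", max_err)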
But if I rewrite the neural network part myself as follows,
import numpy as np
import jax.numpy as jnp
from jax import grad, value_and_grad
from jax import vmap                      # for auto-vectorizing functions
from functools import partial             # for use with vmap
from jax import jit                       # for compiling functions for speedup
from jax import random                    # stax initialization uses jax.random
from jax.experimental import stax         # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax  # neural network layers
from jax.experimental import optimizers
from jax.tree_util import tree_multimap
import matplotlib.pyplot as plt           # visualization

def initialize_NN(layers, key):
    params = []
    num_layers = len(layers)
    keys = random.split(key, len(layers))
    a = jnp.sqrt(0.1)
    #params.append(a)
    for l in range(0, num_layers - 1):
        W = xavier_init((layers[l], layers[l+1]), keys[l])
        b = jnp.zeros((layers[l+1],), dtype=np.float32)
        params.append((W, b))
    return params

def xavier_init(size, key):
    in_dim = size[0]
    out_dim = size[1]
    xavier_stddev = jnp.sqrt(2 / (in_dim + out_dim))
    return random.truncated_normal(key, -2, 2, shape=(out_dim, in_dim), dtype=np.float32) * xavier_stddev

def net_apply(params, X):
    num_layers = len(params)
    #a = params[0]
    for l in range(0, num_layers - 1):
        W, b = params[l]
        X = jnp.maximum(0, jnp.add(jnp.dot(X, W.T), b))
    W, b = params[-1]
    Y = jnp.dot(X, W.T) + b
    Y = jnp.squeeze(Y)
    return Y

def loss(params, X, Y):
    predictions = net_apply(params, X)
    return jnp.mean((Y - predictions)**2)

key = random.PRNGKey(1)
layers = [1, 40, 40, 40, 1]
params = initialize_NN(layers, key)

@jit
def step(i, opt_state, x1, y1):
    p = get_params(opt_state)
    val, g = value_and_grad(loss)(p, x1, y1)
    return val, opt_update(i, g, opt_state)

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)

xrange_inputs = jnp.linspace(-5, 5, 100).reshape((100, 1))  # (k, 1)
targets = jnp.cos(xrange_inputs)

val_his = []
for i in range(10000):
    val, opt_state = step(i, opt_state, xrange_inputs, targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)

predictions = vmap(partial(net_apply, params))(xrange_inputs)
losses = vmap(partial(loss, params))(xrange_inputs, targets)  # per-input loss

plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()
my neural network always converges to a constant, as if it were trapped in a local minimum. I am really confused about this.
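To make the symptom concrete, here is a small diagnostic (a sketch that reuses the variables defined in the second block): it prints the range of the network output over the input grid, along with the shapes that go into the loss.

preds_full = net_apply(params, xrange_inputs)   # full-batch forward pass of the hand-written net
print("prediction shape:", preds_full.shape, "target shape:", targets.shape)
print("prediction range:", float(preds_full.min()), "to", float(preds_full.max()))
print("shape of (targets - preds_full):", (targets - preds_full).shape)  # what the loss actually averages over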
The only differences should be the initialization, the network code itself, and the way the parameter params is set up. I have tried different initializations, and they make no difference. I wonder whether the way params is set up for the optimizer is wrong, and that is why I cannot get convergence.
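One thing that might help narrow it down is comparing the parameter pytrees that the two versions hand to optimizers.adam. This is only a sketch; params_stax and params_manual are hypothetical names for the params produced by the first and second code blocks, respectively.

from jax import tree_util

print(tree_util.tree_map(lambda p: p.shape, params_stax))    # parameters from stax.serial
print(tree_util.tree_map(lambda p: p.shape, params_manual))  # parameters from initialize_NN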