Wednesday, March 24, 2021

How to implement neural network based reinforcement learning (DQN, Actor-Critic, ..) in Java? (Custom Loss Backpropagation)

Note: my problem is the same as the one someone else described here: "Backpropagate custom loss in DL4J".

I want to use neural-network-based approaches to make an agent learn how to behave in an environment, i.e., using states, actions and rewards. The environment is implemented in Java, so I intended to use the DL4J library to implement the networks. However, I am struggling with backpropagation (BPG): the networks I need (e.g., actor-critic) require a custom loss computation whose result is then used in BPG, and I cannot get this to work. Building the network with MultiLayerNetwork and training it with fit(..) works, and so does manually setting the inputs and computing gradients with the optimizer, but there seems to be no option to compute a custom loss and use it in BPG instead of the standard loss computed from the output layer's outputs.
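For the DQN half of the question, a custom loss is often not needed at all: a common trick is to take the network's own Q-value predictions for state s, overwrite only the entry of the action actually taken with the TD target r + γ·max Q(s', ·), and pass that vector as the label to an ordinary MSE fit(..), so every other entry contributes zero error. For actor-critic, the actor loss −A·log π(a|s) with a softmax policy has a closed-form gradient with respect to the logits, A·(π − onehot(a)), which is the error signal one would need to push back through the network. The sketch below is plain Java over double[] and deliberately makes no DL4J calls; the class and method names (RlTargets, dqnLabel, actorLogitGradient) are hypothetical, and whether and how your DL4J version lets you inject such an externally computed output gradient is an assumption to check against its documentation.

```java
import java.util.Arrays;

/**
 * Hypothetical helpers (not a DL4J API) showing how two RL losses reduce to
 * a label vector or an output-layer gradient, using plain double[] arrays.
 */
public final class RlTargets {

    /**
     * DQN: build a regression label for an ordinary MSE fit.
     * qCurrent = network predictions for s; qNext = predictions for s'
     * (from a target network, if you use one). Only the taken action's entry
     * changes, so the other outputs get zero error under MSE.
     */
    public static double[] dqnLabel(double[] qCurrent, int action, double reward,
                                    double[] qNext, double gamma, boolean terminal) {
        double[] label = qCurrent.clone();
        double maxNext = Arrays.stream(qNext).max().orElse(0.0);
        label[action] = terminal ? reward : reward + gamma * maxNext;
        return label;
    }

    /** Numerically stable softmax over action logits. */
    public static double[] softmax(double[] logits) {
        double max = Arrays.stream(logits).max().orElse(0.0);
        double[] p = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    /**
     * Actor-critic: gradient of the actor loss L = -A * log pi(a|s) with
     * respect to the logits, where pi = softmax(logits) and
     * A = r + gamma * V(s') - V(s) is the one-step advantage.
     * dL/dlogit_i = A * (pi_i - 1[i == a]). A is treated as a constant,
     * so no gradient flows from this loss into the critic.
     */
    public static double[] actorLogitGradient(double[] logits, int action, double reward,
                                              double vState, double vNext, double gamma,
                                              boolean terminal) {
        double target = terminal ? reward : reward + gamma * vNext;
        double advantage = target - vState;
        double[] grad = softmax(logits);
        grad[action] -= 1.0;
        for (int i = 0; i < grad.length; i++) {
            grad[i] *= advantage;
        }
        return grad;
    }

    public static void main(String[] args) {
        double[] label = dqnLabel(new double[]{1.0, 2.0, 3.0}, 0, 1.0,
                new double[]{0.5, 4.0, 1.0}, 0.9, false);
        System.out.println("DQN label: " + Arrays.toString(label));
        double[] g = actorLogitGradient(new double[]{0.0, 0.0}, 1, 1.0, 0.0, 0.0, 0.9, true);
        System.out.println("Actor logit gradient: " + Arrays.toString(g));
    }
}
```

The critic V(s) can itself be trained with an ordinary regression label r + γ·V(s') (or just r at terminal states), so in this formulation only the actor needs a hand-computed gradient.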

I thought about using another library, but DL4J already seemed the most sophisticated option to me. Alternatively, I could implement the networks in Python and train them by sending the training samples/labels from the Java program to a Python script via interprocess communication, although this seems complicated.
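If you do go the interprocess route, the Java side mainly needs a simple, language-neutral wire format for transitions. A minimal sketch, assuming a hypothetical one-line-per-transition, tab-separated format (state, action, reward, next state, done flag) that a Python trainer would read line by line; the class name TransitionCodec and the format itself are illustrative, not an existing protocol.

```java
import java.util.Arrays;

/**
 * Hypothetical line-based wire format for RL transitions:
 *   state \t action \t reward \t nextState \t done
 * Vectors are comma-separated doubles, done is "1" or "0".
 * Assumes state vectors are non-empty.
 */
public final class TransitionCodec {

    /** One (s, a, r, s', done) sample produced by the Java environment. */
    public static final class Transition {
        public final double[] state;
        public final int action;
        public final double reward;
        public final double[] nextState;
        public final boolean done;

        public Transition(double[] state, int action, double reward,
                          double[] nextState, boolean done) {
            this.state = state;
            this.action = action;
            this.reward = reward;
            this.nextState = nextState;
            this.done = done;
        }
    }

    // Double.toString is locale-independent, and parseDouble recovers the exact value.
    private static String vec(double[] v) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < v.length; i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(Double.toString(v[i]));
        }
        return sb.toString();
    }

    private static double[] parseVec(String s) {
        return Arrays.stream(s.split(",")).mapToDouble(Double::parseDouble).toArray();
    }

    /** Serializes a transition to a single line (without a trailing newline). */
    public static String encode(Transition t) {
        return String.join("\t",
                vec(t.state),
                Integer.toString(t.action),
                Double.toString(t.reward),
                vec(t.nextState),
                t.done ? "1" : "0");
    }

    /** Parses a line produced by encode back into a Transition. */
    public static Transition decode(String line) {
        String[] f = line.split("\t");
        return new Transition(parseVec(f[0]), Integer.parseInt(f[1]),
                Double.parseDouble(f[2]), parseVec(f[3]), "1".equals(f[4]));
    }
}
```

Lines from encode(..) could, for example, be written to the stdin of a child process started with ProcessBuilder (running a hypothetical trainer script). Getting actions or updated weights back into Java needs a second channel, which is where most of the anticipated complexity would come from.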

Has anyone had a similar experience? Any ideas on how to go about this?

https://stackoverflow.com/questions/66792045/how-to-implement-neural-network-based-reinforcement-learning-dqn-actor-critic March 25, 2021 at 10:04AM
