Commit 3b124cf1 authored by Alexis Brenon

Fix agent dumping

parent 788f3cff
@@ -29,11 +29,11 @@ local class, super = torch.class('NeuralQLearner', 'BaseAgent', module)
 -- @tparam string args.inference.class Name of the class to instantiate
 -- @tparam table args.inference.params Parameters of the class (see class documentation)
 --
--- @tparam table args.memory Parameters of the memory of the agent
--- @tparam number args.memory.pool_size Size of the @{agent._ExperiencePool|experience pool}
--- @tparam number args.memory.history_length Length of the history for inference
--- @tparam number args.memory.history_type Type of history (see @{agent._ExperiencePool})
--- @tparam number args.memory.history_spacing Spacing in history (see @{agent._ExperiencePool})
+-- @tparam table args.experience_pool Parameters of the memory of the agent
+-- @tparam number args.experience_pool.pool_size Size of the @{agent._ExperiencePool|experience pool}
+-- @tparam number args.experience_pool.history_length Length of the history for inference
+-- @tparam number args.experience_pool.history_type Type of history (see @{agent._ExperiencePool})
+-- @tparam number args.experience_pool.history_spacing Spacing in history (see @{agent._ExperiencePool})
 --
 -- @tparam[opt=0] number args.learn_start Number of steps after which learning starts
 -- @tparam[opt=1] number args.update_freq Learning frequency (epoch size)
@@ -55,8 +55,12 @@ local class, super = torch.class('NeuralQLearner', 'BaseAgent', module)
 -- @tparam[opt=nil] number args.clip_delta Clipping value for delta
 -- @tparam[opt=nil] number args.target_q How long a target network is valid
 -- @tparam[opt=0] number args.wc L2 weight cost.
-function class:__init(args)
-    super.__init(self, args)
+function class:__init(args, dump)
+    super.__init(self, args, dump)
+    if not args then
+        self._logger:warn("No argument provided!")
+    end
+    args = args or {}; dump = dump or {}
     --- Does the agent must use GPU or not?
     self.use_gpu = package.loaded["cutorch"]
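
An agent is now built from two tables: `args` holds configuration, `dump` holds previously saved state. A minimal construction sketch, assuming the class is exposed through an `agent` module (the module path, network class and sizes below are illustrative, not taken from this commit):

-- Fresh agent: everything comes from `args`, no second argument.
local agent = require("arcades.agent")  -- assumed path, mirroring "arcades.network"
local learner = agent.NeuralQLearner({
    observation_size = {1, 84, 84},                            -- illustrative
    experience_pool = {pool_size = 4096, history_length = 4},
    inference = {class = "convnet", params = {}},              -- hypothetical class
})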
@@ -77,15 +81,16 @@ function class:__init(args)
     self:_init_preprocessing_network(args.preprocess, args.observation_size)
-    args.memory.state_size = self.preprocess.output_size
+    args.experience_pool = args.experience_pool or {}
+    args.experience_pool.state_size = self.preprocess.output_size
     --- @{agent._ExperiencePool|Experience pool} recording interactions.
     --
     -- This experience pool will act as a memory for the agent, recording
     -- interactions, and returning them when necessary (e.g. when learning).
     -- @tfield agent._ExperiencePool self.experience_pool
-    self.experience_pool = agent._ExperiencePool(args.memory)
+    self.experience_pool = agent._ExperiencePool(args.experience_pool, dump.experience_pool)
-    self:_init_inference_network(args.inference)
+    self:_init_inference_network(args.inference, dump.inference)
     --- Number of steps after which learning starts.
     -- Add delay to populate the experience pool
@@ -122,7 +127,7 @@ function class:__init(args)
     self.ep_start = args.ep_start or 1
     --- Current espilon value.
     -- @tfield number self.ep
-    self.ep = args.ep or self.ep_start
+    self.ep = dump.ep or self.ep_start
     --- Final value of epsilon.
     -- @tfield number self.ep_end
     self.ep_end = args.ep_end or 0.25
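
For context, `ep` anneals from `ep_start` towards `ep_end` over `ep_endt` steps. The agent's exact schedule is outside this diff; the conventional DQN linear anneal, for reference only, looks like:

-- Reference linear anneal; hypothetical, not taken from this codebase.
local function annealed_epsilon(ep_start, ep_end, ep_endt, step)
    local remaining = math.max(0, ep_endt - step) / ep_endt
    return ep_end + (ep_start - ep_end) * remaining
end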
@@ -172,7 +177,7 @@ function class:__init(args)
         decay = 0.95,
         mu = 0.01
     }
-    if args.rmsprop then
+    if dump.rmsprop then
         self.rmsprop.mean_square = self._convert_tensor(self.rmsprop.mean_square)
         self.rmsprop.mean = self._convert_tensor(self.rmsprop.mean)
     end
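
Dumped rmsprop statistics are stored as CPU tensors, so on reload they are cast back to the session's working type. A sketch of the converter implied by the surrounding code (the CUDA branch is an assumption based on the `cutorch` check):

local convert_tensor
if package.loaded["cutorch"] then
    convert_tensor = function(t) return t:cuda() end  -- assumed GPU branch
else
    local default_type = torch.getdefaulttensortype()
    convert_tensor = function(t) return t:type(default_type) end  -- as in the pool code below
end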
@@ -184,8 +189,12 @@ function class:__init(args)
         self.target_network = self.inference.network:clone()
     end
-    -- Put the agent in training mode
-    self:training()
+    if (dump.evaluating == nil) then
+        -- Put the agent in training mode
+        self:training()
+    else
+        self.evaluating = dump.evaluating
+    end
 end
 --- Public Methods
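
Together with the constructor changes, a save/restore round trip now hands the whole snapshot back as the second argument (the file name and `agent` module path are illustrative):

local snapshot = learner:dump()                 -- CPU-loadable table (see dump() below)
torch.save("agent.t7", snapshot)
local restored = agent.NeuralQLearner(nil, torch.load("agent.t7"))
-- ep, the rmsprop statistics and the evaluating flag are restored from the dump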
@@ -281,36 +290,6 @@ function class:dump(cycles)
     local default_type = torch.getdefaulttensortype()
     -- Save arguments and state variables
     local dump = super.dump(self, cycles)
-    dump.preprocess = copy_table(self.preprocess)
-    dump.inference = copy_table(self.inference)
-    dump.rmsprop = copy_table(self.rmsprop)
-    dump.memory = self.experience_pool:dump()
-    dump.learn_start = self.learn_start
-    dump.update_freq = self.update_freq
-    dump.minibatch_size = self.minibatch_size
-    dump.n_replay = self.n_replay
-    dump.rescale_r = self.rescale_r
-    dump.max_reward = self.max_reward
-    dump.min_reward = self.min_reward
-    dump.r_max = self.r_max
-    dump.ep_start = self.ep_start
-    dump.ep_end = self.ep_end
-    dump.ep_endt = self.ep_endt
-    dump.ep_eval = self.ep_eval
-    dump.ep = self.ep
-    dump.lr = self.lr
-    dump.discount = self.discount
-    dump.clip_delta = self.clip_delta
-    dump.target_q = self.target_q
-    dump.wc = self.wc
-    dump.experienced_steps = self.experienced_steps
-    dump.learning_epoch = self.learning_epoch
     -- Convert tensors to CPU loadable ones
     dump.target_network = self.target_network:clone():type(default_type)
@@ -324,24 +303,6 @@ function class:dump(cycles)
     return dump
 end
-function class:__tostring__()
-    local result = torch.typename(self) .. "(\n"
-    local dump = self:dump()
-    local copy_table = require('pl.tablex').copy
-    local rw_dump = copy_table(dump)
-    rw_dump.rmsprop = copy_table(dump.rmsprop)
-    rw_dump.rmsprop.mean_square = torch.typename(self.rmsprop.mean_square) .. " of size " .. self.rmsprop.mean_square:size(1)
-    rw_dump.rmsprop.mean = torch.typename(self.rmsprop.mean) .. " of size " .. self.rmsprop.mean:size(1)
-    rw_dump.target_network = tostring(self.target_network)
-    rw_dump.preprocess = copy_table(dump.preprocess)
-    rw_dump.preprocess.network = tostring(self.preprocess.network)
-    rw_dump.inference = copy_table(dump.inference)
-    rw_dump.inference.network = tostring(self.inference.network)
-    rw_dump.memory = tostring(self.experience_pool)
-    result = result .. require('pl.pretty').write(rw_dump) .. ")"
-    return result
-end
 --- Private Methods
 -- @section private-methods
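
The hand-maintained field list above is dropped because `super.dump` now snapshots the component's state generically. The base-class implementation is not part of this commit; judging from the `_no_dump_list` marker set on the experience pool's hasher further down, it plausibly iterates the instance table and skips blacklisted fields, roughly:

-- Hypothetical sketch of the generic ArcadesComponent dump; the real
-- implementation is not shown in this commit.
function ArcadesComponent:dump(cycles)
    local dump = {}
    for key, value in pairs(self) do
        if not self._no_dump_list[key] then
            dump[key] = value  -- sub-components are dumped separately (see below)
        end
    end
    return dump
end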
@@ -363,7 +324,7 @@ function class:_init_preprocessing_network(args, observation_size)
         -- Reload a dumped agent
         preprocess = args
     elseif args then
-        local network = require("network")
+        local network = require("arcades.network")
         if not network[args.class] then
             error(string.format(
                 "Unable to find '%s' in network package",
@@ -393,9 +354,10 @@ end
 -- @tparam[opt] string args.class Name of the class to instantiate
 -- @tparam[opt] string args.file Path to a previously dumped network to reload
 -- @tparam[opt] table args.params Parameters
-function class:_init_inference_network(args)
+function class:_init_inference_network(args, dump)
+    args = args or {}
     -- Check that we reload a dump agent, or load a saved file or instantiate a class
-    assert(args.network or args.file or args.class, "No network was given to the agent.")
+    assert(args.network or args.file or args.class or dump, "No network was given to the agent.")
     --- Main deep neural network.
     -- This network is used to get the best action given a history of preprocessed states
     -- @tfield nn.Module self.inference.network The actual network
@@ -404,12 +366,12 @@ function class:_init_inference_network(args)
     -- @tfield torch.Tensor self.inference.parameters Flat view of learnable parameters
     -- @tfield torch.Tensor self.inference.grad_parameters Flat view of gradient of energy wrt the learnable parameters
     -- @table self.inference
-    local network = require("network")
+    local network = require("arcades.network")
     local inference
-    if args and args.network then
-        -- Restore from a dumped agent
-        inference = args
+    if dump then
+        -- Restore from a dumped agent (new architecture 2017-07-06)
+        inference = dump
     elseif args.file then
         network = network.load_all() -- luacheck: no unused
         self._logger:info('Loading network from ' .. args.file)
......
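
After this change `_init_inference_network` has three mutually exclusive sources for the network, sketched here with illustrative values (the class name and file path are hypothetical):

-- Inside the agent; values below are purely illustrative.
self:_init_inference_network({class = "convnet", params = {}})  -- instantiate a class
self:_init_inference_network({file = "pretrained.t7"})          -- load a saved network
self:_init_inference_network(nil, saved_dump.inference)         -- restore from a dump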
@@ -73,8 +73,9 @@ local class, super = torch.class('_ExperiencePool', 'ArcadesComponent', module)
 -- @tparam[opt=1] number args.history_length Length of a full state (with current state plus historic ones)
 -- @tparam[opt="linear"] string args.history_type Function used to grab historic states (linear or exp)
 -- @tparam[opt=1] number args.history_spacing Spacing parameter for @{history_type}
-function class:__init(args)
-    super.__init(self, args)
+function class:__init(args, dump)
+    super.__init(self, args, dump)
+    args = args or {}; dump = dump or {}
     self.use_gpu = package.loaded["cutorch"]
     --- Function to convert tensors if necessary.
     --
@@ -90,13 +91,12 @@ function class:__init(args)
         self._convert_tensor = function(t) return t:type(default_type) end
     end
-    self.states = args.states or {}
-    if self.use_gpu and args.states then
-        self.states = self._convert_states(args.states, self._convert_tensor)
-    end
-    self.hashed_states = args.hashed_states or 0
+    self.states = dump.states or {}
+    self.states = self._convert_states(self.states, self._convert_tensor)
+    self.hashed_states = dump.hashed_states or 0
     self.nil_state = self._convert_tensor(
-        args.nil_state or
+        dump.nil_state or
         torch.zeros(1):repeatTensor(
             table.unpack(
                 (assert(
@@ -107,20 +107,21 @@ function class:__init(args)
         )
     )
     self.hasher = hash.XXH64(0)
+    self._no_dump_list.hasher = true
-    if args.pool then
+    if dump.pool then
         -- Restore a dumped pool
-        self.pool = args.pool
+        self.pool = dump.pool
     else
         self.pool = {}
         self.pool.max_size = args.pool_size or 4096
         self:clear()
     end
-    self.pushed_pools = args.pushed_pools or {}
+    self.pushed_pools = dump.pushed_pools or {}
-    self.history_length = args.history_length or 1
-    self.history_type = args.history_type or "linear"
-    self.history_spacing = args.history_spacing or 1
+    self.history_length = dump.history_length or 1
+    self.history_type = dump.history_type or "linear"
+    self.history_spacing = dump.history_spacing or 1
     self.history_offsets = self:_compute_history_offsets()
     self.history_stacked_state_size = self.nil_state:size():totable()
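
`_compute_history_offsets` itself is outside this diff; given the documented `linear`/`exp` history types and the `history_spacing` parameter, a plausible (hypothetical) sketch is:

-- Hypothetical offsets for stacking past states; the real
-- _compute_history_offsets is not part of this commit.
function class:_compute_history_offsets()
    local offsets = {}
    for i = 0, self.history_length - 1 do
        if self.history_type == "exp" then
            offsets[#offsets + 1] = math.floor(self.history_spacing ^ i)  -- 1, s, s^2, ...
        else
            offsets[#offsets + 1] = i * self.history_spacing              -- 0, s, 2s, ...
        end
    end
    return offsets
end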
@@ -432,14 +433,6 @@ end
 function class:dump(cycles)
     self:_clean_states()
     local dump = super.dump(self, cycles)
-    dump = {
-        hashed_states = self.hashed_states,
-        pool = self.pool,
-        pushed_pools = self.pushed_pools,
-        history_length = self.history_length,
-        history_type = self.history_type,
-        history_spacing = self.history_spacing
-    }
     local default_type = torch.getdefaulttensortype()
     local cpu_converter = function(s) return s:type(default_type) end
     dump.states = self._convert_states(self.states, cpu_converter)
@@ -447,21 +440,6 @@ function class:dump(cycles)
     return dump
 end
-function class:__tostring__()
-    local dump = self:dump()
-    local result = dump.classname .. "(\n"
-    dump.nil_state = nil
-    dump.states = "hashtable of size : " .. dump.hashed_states
-    dump.pool = {
-        max_size = dump.pool.max_size,
-    }
-    dump.pushed_pools = #dump.pushed_pools .. " saved pools"
-    result = result .. require('pl.pretty').write(dump) .. ")"
-    return result
-end
 function class._convert_states(states, f)
     local result = {}
     for h, s in pairs(states) do
......
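
`_convert_states` (partially visible above) maps a converter over the table of hashed states; the commit uses it symmetrically in both directions:

-- Usage taken from the surrounding code: the same helper moves states
-- to CPU tensors for dumping and back to the working type on reload.
local default_type = torch.getdefaulttensortype()
local cpu_converter = function(s) return s:type(default_type) end
dump.states = self._convert_states(self.states, cpu_converter)          -- dumping
self.states = self._convert_states(dump.states, self._convert_tensor)   -- restoring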
@@ -39,7 +39,7 @@ return function(args)
     args.fc_layers = args.fc_layers or {
         512
     }
-    args.nl_layer = args.nl_layer or nn.ReLU
+    args.nl_layer = args.nl_layer or "ReLU"
     args.output_size = args.output_size or {1, 33}
     return network.create_network(args)
......
@@ -28,7 +28,7 @@ function module.create_network(args)
             layer.stride.width, layer.stride.height,
             layer.zero_padding.width, layer.zero_padding.height
         ))
-        net:add(args.nl_layer(true))
+        net:add(nn[args.nl_layer](true))
     end
     -- Convert multidimensionnal output to 1D input
@@ -38,7 +38,7 @@ function module.create_network(args)
     -- Hidden Linear layers
     for _, layer_size in ipairs(args.fc_layers) do
         net:add(nn.Linear(nelements, layer_size))
-        net:add(args.nl_layer(true))
+        net:add(nn[args.nl_layer](true))
         nelements = layer_size
     end
......
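
The last two hunks follow from storing the nonlinearity as a string: a constructor value like `nn.ReLU` does not serialize cleanly in an arguments dump, while the string "ReLU" does, and the constructor is simply looked up in the `nn` package at build time:

-- Equivalent ways to add the nonlinearity; the string form keeps the
-- network arguments torch.save-able.
local nn = require("nn")
local net = nn.Sequential()
local nl_layer = "ReLU"          -- serializable, unlike the nn.ReLU function value
net:add(nn[nl_layer](true))      -- same as net:add(nn.ReLU(true)), in-place variant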