Commit ce2b15cc authored by Alexis Brenon's avatar Alexis Brenon
Browse files

馃攢 Merge branch 'master' into xp/A1-real-annot-1

Conflicts:
	arcades/environment/smarthome/sweethome/GraphicalAnnotatedSweetHome.lua
	expe_template.lua
	src/environment/states/datasource/init.lua
parents 5159f509 8da75836
......@@ -4,3 +4,6 @@ data
# Vim swap files
.*.swp
# TREPL history file
/.history
......@@ -5,16 +5,14 @@
local torch = require('torch')
local arcades = require('arcades')
local module = {}
local class = torch.class('BaseAgent', module)
local class, super = torch.class('BaseAgent', 'ArcadesComponent', module)
--- Abstract constructor.
function class:__init()
error(string.format(
"The %s class is not intented to be instanciated.",
torch.typename(self)
),
2)
function class:__init(args)
super.__init(self, args)
end
--- Interface
......@@ -110,22 +108,5 @@ function class:report()
return ""
end
--- Dump informations of the agent.
--
-- Return some useful informations that can be reloaded later to restore the
-- agent.
-- @return Any serializable object
function class:dump()
return {
classname = torch.type(self)
}
end
--luacheck: pop
function class:__tostring__()
local result = torch.typename(self) .. "(\n"
result = result .. require('pl.pretty').write(self:dump()) .. ")"
return result
end
return module.BaseAgent
......@@ -8,9 +8,9 @@
local nn = require('nn')
local torch = require('torch')
local logging = require('utils.logging')
local arcades = require('arcades')
local agent = require('agent')
local agent = require('arcades.agent')
assert(agent.BaseAgent)
local module = {}
local class, super = torch.class('NeuralQLearner', 'BaseAgent', module)
......@@ -29,11 +29,11 @@ local class, super = torch.class('NeuralQLearner', 'BaseAgent', module)
-- @tparam string args.inference.class Name of the class to instantiate
-- @tparam table args.inference.params Parameters of the class (see class documentation)
--
-- @tparam table args.memory Parameters of the memory of the agent
-- @tparam number args.memory.pool_size Size of the @{agent._ExperiencePool|experience pool}
-- @tparam number args.memory.history_length Length of the history for inference
-- @tparam number args.memory.history_type Type of history (see @{agent._ExperiencePool})
-- @tparam number args.memory.history_spacing Spacing in history (see @{agent._ExperiencePool})
-- @tparam table args.experience_pool Parameters of the memory of the agent
-- @tparam number args.experience_pool.pool_size Size of the @{agent._ExperiencePool|experience pool}
-- @tparam number args.experience_pool.history_length Length of the history for inference
-- @tparam number args.experience_pool.history_type Type of history (see @{agent._ExperiencePool})
-- @tparam number args.experience_pool.history_spacing Spacing in history (see @{agent._ExperiencePool})
--
-- @tparam[opt=0] number args.learn_start Number of steps after which learning starts
-- @tparam[opt=1] number args.update_freq Learning frequency (epoch size)
......@@ -55,9 +55,12 @@ local class, super = torch.class('NeuralQLearner', 'BaseAgent', module)
-- @tparam[opt=nil] number args.clip_delta Clipping value for delta
-- @tparam[opt=nil] number args.target_q How long a target network is valid
-- @tparam[opt=0] number args.wc L2 weight cost.
function class:__init(args)
self.logger = logging.new(logging.timestamp_output("NeuralQLearner"))
self.logger:setLevel(logging.default_level)
function class:__init(args, dump)
super.__init(self, args, dump)
if not args then
self._logger:warn("No argument provided!")
end
args = args or {}; dump = dump or {}
--- Does the agent must use GPU or not?
self.use_gpu = package.loaded["cutorch"]
......@@ -78,15 +81,16 @@ function class:__init(args)
self:_init_preprocessing_network(args.preprocess, args.observation_size)
args.memory.state_size = self.preprocess.output_size
args.experience_pool = args.experience_pool or {}
args.experience_pool.state_size = self.preprocess.output_size
--- @{agent._ExperiencePool|Experience pool} recording interactions.
--
-- This experience pool will act as a memory for the agent, recording
-- interactions, and returning them when necessary (e.g. when learning).
-- @tfield agent._ExperiencePool self.experience_pool
self.experience_pool = agent._ExperiencePool(args.memory)
self.experience_pool = agent._ExperiencePool(args.experience_pool, dump.experience_pool)
self:_init_inference_network(args.inference)
self:_init_inference_network(args.inference, dump.inference)
--- Number of steps after which learning starts.
-- Add delay to populate the experience pool
......@@ -123,7 +127,7 @@ function class:__init(args)
self.ep_start = args.ep_start or 1
--- Current espilon value.
-- @tfield number self.ep
self.ep = args.ep or self.ep_start
self.ep = dump.ep or self.ep_start
--- Final value of epsilon.
-- @tfield number self.ep_end
self.ep_end = args.ep_end or 0.25
......@@ -173,7 +177,7 @@ function class:__init(args)
decay = 0.95,
mu = 0.01
}
if args.rmsprop then
if dump.rmsprop then
self.rmsprop.mean_square = self._convert_tensor(self.rmsprop.mean_square)
self.rmsprop.mean = self._convert_tensor(self.rmsprop.mean)
end
......@@ -185,8 +189,12 @@ function class:__init(args)
self.target_network = self.inference.network:clone()
end
-- Put the agent in training mode
self:training()
if (dump.evaluating ~= nil) then
-- Put the agent in training mode
self:training()
else
self.evaluating = dump.evaluating
end
end
--- Public Methods
......@@ -260,7 +268,7 @@ end
--- Overriden method
-- @see BaseAgent:training
function class:training()
self.logger:debug("Passing in TRAINING mode")
self._logger:debug("Passing in TRAINING mode")
self.evaluating = false
self.experience_pool:pop()
return self
......@@ -269,7 +277,7 @@ end
--- Overriden method
-- @see BaseAgent:evaluate
function class:evaluate()
self.logger:debug("Passing in EVALUATING mode")
self._logger:debug("Passing in EVALUATING mode")
self.evaluating = true
self.experience_pool:push()
return self
......@@ -277,41 +285,11 @@ end
--- Overriden method
-- @see BaseAgent:dump
function class:dump()
function class:dump(cycles)
local copy_table = require('pl.tablex').copy
local default_type = torch.getdefaulttensortype()
-- Save arguments and state variables
local dump = super.dump(self)
dump.preprocess = copy_table(self.preprocess)
dump.inference = copy_table(self.inference)
dump.rmsprop = copy_table(self.rmsprop)
dump.memory = self.experience_pool:dump()
dump.learn_start = self.learn_start
dump.update_freq = self.update_freq
dump.minibatch_size = self.minibatch_size
dump.n_replay = self.n_replay
dump.rescale_r = self.rescale_r
dump.max_reward = self.max_reward
dump.min_reward = self.min_reward
dump.r_max = self.r_max
dump.ep_start = self.ep_start
dump.ep_end = self.ep_end
dump.ep_endt = self.ep_endt
dump.ep_eval = self.ep_eval
dump.ep = self.ep
dump.lr = self.lr
dump.discount = self.discount
dump.clip_delta = self.clip_delta
dump.target_q = self.target_q
dump.wc = self.wc
dump.experienced_steps = self.experienced_steps
dump.learning_epoch = self.learning_epoch
local dump = super.dump(self, cycles)
-- Convert tensors to CPU loadable ones
dump.target_network = self.target_network:clone():type(default_type)
......@@ -325,24 +303,6 @@ function class:dump()
return dump
end
function class:__tostring__()
local result = torch.typename(self) .. "(\n"
local dump = self:dump()
local copy_table = require('pl.tablex').copy
local rw_dump = copy_table(dump)
rw_dump.rmsprop = copy_table(dump.rmsprop)
rw_dump.rmsprop.mean_square = torch.typename(self.rmsprop.mean_square) .. " of size " .. self.rmsprop.mean_square:size(1)
rw_dump.rmsprop.mean = torch.typename(self.rmsprop.mean) .. " of size " .. self.rmsprop.mean:size(1)
rw_dump.target_network = tostring(self.target_network)
rw_dump.preprocess = copy_table(dump.preprocess)
rw_dump.preprocess.network = tostring(self.preprocess.network)
rw_dump.inference = copy_table(dump.inference)
rw_dump.inference.network = tostring(self.inference.network)
rw_dump.memory = tostring(self.experience_pool)
result = result .. require('pl.pretty').write(rw_dump) .. ")"
return result
end
--- Private Methods
-- @section private-methods
......@@ -364,7 +324,7 @@ function class:_init_preprocessing_network(args, observation_size)
-- Reload a dumped agent
preprocess = args
elseif args then
local network = require("network")
local network = require("arcades.network")
if not network[args.class] then
error(string.format(
"Unable to find '%s' in network package",
......@@ -394,9 +354,10 @@ end
-- @tparam[opt] string args.class Name of the class to instantiate
-- @tparam[opt] string args.file Path to a previously dumped network to reload
-- @tparam[opt] table args.params Parameters
function class:_init_inference_network(args)
function class:_init_inference_network(args, dump)
args = args or {}
-- Check that we reload a dump agent, or load a saved file or instantiate a class
assert(args.network or args.file or args.class, "No network was given to the agent.")
assert(args.network or args.file or args.class or dump, "No network was given to the agent.")
--- Main deep neural network.
-- This network is used to get the best action given a history of preprocessed states
-- @tfield nn.Module self.inference.network The actual network
......@@ -405,15 +366,15 @@ function class:_init_inference_network(args)
-- @tfield torch.Tensor self.inference.parameters Flat view of learnable parameters
-- @tfield torch.Tensor self.inference.grad_parameters Flat view of gradient of energy wrt the learnable parameters
-- @table self.inference
local network = require("network")
local network = require("arcades.network")
local inference
if args and args.network then
-- Restore from a dumped agent
inference = args
if dump then
-- Restore from a dumped agent (new architecture 2017-07-06)
inference = dump
elseif args.file then
network = network.load_all() -- luacheck: no unused
self.logger:info('Loading network from ' .. args.file)
self._logger:info('Loading network from ' .. args.file)
local status, result = pcall(torch.load, args.file)
if not status then
error(string.format(
......@@ -436,7 +397,7 @@ function class:_init_inference_network(args)
args.class
), 2)
end
self.logger:info('Creating network from network.' .. args.class)
self._logger:info('Creating network from network.' .. args.class)
inference.network = network[args.class](args.params)
end
inference.network = self._convert_tensor(inference.network)
......@@ -496,7 +457,7 @@ end
function class:_greedy(state)
-- Turn single state into minibatch. Needed for convolutional nets.
if state:dim() == 2 then
self.logger:warn('ConvNet input must be at least 3D. Adding a new dimension of size 1')
self._logger:warn('ConvNet input must be at least 3D. Adding a new dimension of size 1')
state = state:resize(1, state:size(1), state:size(2))
end
......
......@@ -7,19 +7,16 @@
local torch = require('torch')
local logging = require('utils/logging')
local agent = require('agent')
local agent = require('arcades.agent')
assert(agent.BaseAgent)
local module = {}
local class = torch.class('RandomAgent', 'BaseAgent', module)
local class, super = torch.class('RandomAgent', 'BaseAgent', module)
--- Default constructor.
-- @tparam table args
-- @tparam table args.actions Available actions
function class:__init(args)
self.logger = logging.new(logging.timestamp_output("NeuralQLearner"))
self.logger:setLevel((logging.mainLogger and logging.mainLogger.level) or "DEBUG")
super.__init(self, args)
self.num_actions = #args.actions
self.actions_frequencies = torch.zeros(self.num_actions)
self.current_is_terminal = true
......
......@@ -7,12 +7,13 @@
local hash = require('hash')
local torch = require('torch')
local logging = require('utils.logging')
local arcades = require('arcades')
local copy_table = require('pl.tablex').copy
local module = {}
local class = torch.class('_ExperiencePool', module)
local class, super = torch.class('_ExperiencePool', 'ArcadesComponent', module)
--- Attributes
-- @section attributes
......@@ -72,9 +73,9 @@ local class = torch.class('_ExperiencePool', module)
-- @tparam[opt=1] number args.history_length Length of a full state (with current state plus historic ones)
-- @tparam[opt="linear"] string args.history_type Function used to grab historic states (linear or exp)
-- @tparam[opt=1] number args.history_spacing Spacing parameter for @{history_type}
function class:__init(args)
self.logger = logging.new(logging.timestamp_output("ExperiencePool"))
self.logger:setLevel(logging.default_level)
function class:__init(args, dump)
super.__init(self, args, dump)
args = args or {}; dump = dump or {}
self.use_gpu = package.loaded["cutorch"]
--- Function to convert tensors if necessary.
--
......@@ -90,13 +91,12 @@ function class:__init(args)
self._convert_tensor = function(t) return t:type(default_type) end
end
self.states = args.states or {}
if self.use_gpu and args.states then
self.states = self._convert_states(args.states, self._convert_tensor)
end
self.hashed_states = args.hashed_states or 0
self.states = dump.states or {}
self.states = self._convert_states(self.states, self._convert_tensor)
self.hashed_states = dump.hashed_states or 0
self.nil_state = self._convert_tensor(
args.nil_state or
dump.nil_state or
torch.zeros(1):repeatTensor(
table.unpack(
(assert(
......@@ -107,20 +107,21 @@ function class:__init(args)
)
)
self.hasher = hash.XXH64(0)
self._no_dump_list.hasher = true
if args.pool then
if dump.pool then
-- Restore a dumped pool
self.pool = args.pool
self.pool = dump.pool
else
self.pool = {}
self.pool.max_size = args.pool_size or 4096
self:clear()
end
self.pushed_pools = args.pushed_pools or {}
self.pushed_pools = dump.pushed_pools or {}
self.history_length = args.history_length or 1
self.history_type = args.history_type or "linear"
self.history_spacing = args.history_spacing or 1
self.history_length = dump.history_length or 1
self.history_type = dump.history_type or "linear"
self.history_spacing = dump.history_spacing or 1
self.history_offsets = self:_compute_history_offsets()
self.history_stacked_state_size = self.nil_state:size():totable()
......@@ -429,17 +430,9 @@ function class:pop()
return self
end
function class:dump()
function class:dump(cycles)
self:_clean_states()
local dump = {
classname = torch.type(self),
hashed_states = self.hashed_states,
pool = self.pool,
pushed_pools = self.pushed_pools,
history_length = self.history_length,
history_type = self.history_type,
history_spacing = self.history_spacing
}
local dump = super.dump(self, cycles)
local default_type = torch.getdefaulttensortype()
local cpu_converter = function(s) return s:type(default_type) end
dump.states = self._convert_states(self.states, cpu_converter)
......@@ -447,21 +440,6 @@ function class:dump()
return dump
end
function class:__tostring__()
local dump = self:dump()
local result = dump.classname .. "(\n"
dump.nil_state = nil
dump.states = "hashtable of size : " .. dump.hashed_states
dump.pool = {
max_size = dump.pool.max_size,
}
dump.pushed_pools = #dump.pushed_pools .. " saved pools"
result = result .. require('pl.pretty').write(dump) .. ")"
return result
end
function class._convert_states(states, f)
local result = {}
for h, s in pairs(states) do
......
......@@ -11,4 +11,4 @@
-- @alias package
-- @author Alexis BRENON <alexis.brenon@imag.fr>
return require('utils.package_loader')(...)
return require('arcades.utils.package_loader')(...)
......@@ -5,10 +5,9 @@
local torch = require('torch')
local logging = require('utils.logging')
local arcades = require('arcades')
local module = {}
local class = torch.class('BaseEnvironment', module)
local class, super = torch.class('BaseEnvironment', 'ArcadesComponent', module)
--- Attributes
-- @section attributes
......@@ -23,17 +22,8 @@ local class = torch.class('BaseEnvironment', module)
--- Abstract constructor
function class:__init()
if torch.typename(self) == class.__typename then
error(
string.format(
"%s : abstract class is not intented to be instanciated.",
torch.typename(self)
),
2)
else
self.logger = logging.new(logging.timestamp_output(torch.typename(self)))
end
function class:__init(args)
super.__init(self, args)
end
--- Interface
......@@ -102,25 +92,5 @@ function class:reset()
return self
end
--- Return a serializable object representing this environment.
-- @return self
function class:dump()
return {
classname = torch.typename(self)
}
end
function class:__tostring__()
local result = torch.typename(self) .. "(\n"
-- require('mobdebug').start()
local dump = self:dump()
for k, _ in pairs(dump) do
if torch.typename(self[k]) then
dump[k] = tostring(self[k])
end
end
result = result .. require('pl.pretty').write(dump) .. ")"
return result
end
return module.BaseEnvironment
......@@ -8,17 +8,14 @@
local torch = require('torch')
local logging = require('utils.logging')
local environment = require('environment')
local environment = require('arcades.environment')
assert(environment.BaseEnvironment)
local module = {}
local class = torch.class('DebugEnvironment', 'BaseEnvironment', module)
local class, super = torch.class('DebugEnvironment', 'BaseEnvironment', module)
--- Default constructor
function class:__init()
self.logger = logging.new(logging.timestamp_output("DebugEnvironment"))
self.logger:setLevel(logging.default_level)
function class:__init(args)
super.__init(self, args)
self.terminal = nil
end
......@@ -62,10 +59,6 @@ function class:reset()
return self
end
function class:dump()
return nil
end
--- @section end
return module.DebugEnvironment
......@@ -17,4 +17,4 @@
-- @alias package
-- @author Alexis BRENON <alexis.brenon@imag.fr>
return require('utils.package_loader')(...)
return require('arcades.utils.package_loader')(...)
......@@ -13,4 +13,4 @@
-- @alias package
-- @author Alexis BRENON <alexis.brenon@imag.fr>
return require('utils.package_loader')(...)
return require('arcades.utils.package_loader')(...)
local torch = require('torch')
local environment = require('environment')
local environment = require('arcades.environment')
assert(environment.smarthome.sweethome.SweetHome)
local module = {}
local class, super = torch.class('AnnotatedSweetHome', "SweetHome", module)
......
......@@ -2,7 +2,7 @@
local torch = require('torch')
local environment = require('environment')
local environment = require('arcades.environment')
assert(environment.smarthome.sweethome.AnnotatedSweetHome)
local module = {}
......@@ -10,8 +10,9 @@ local class, super = torch.class('GraphicalAnnotatedSweetHome', 'AnnotatedSweetH
function class:__init(args)
args = args or {}
args.environment_model = self
super.__init(self, args)
args.environment_model = self
args.no_action_ratio = args.no_action_ratio or 0
args.no_action_penalty = args.no_action_penalty or (
......@@ -35,12 +36,12 @@ function class:__init(args)
local data_sources = environment.states.datasource
if not args.data_path then
self.logger:info("Using synthetic data")
self._logger:info("Using synthetic data")
self.data_source = data_sources.simulated.SweetHomeDataInferredSimulated(
args
)
else
self.logger:info("Using real data")
self._logger:info("Using real data")
self.data_source = data_sources.labelled.AnnotatedLabelledDataProvider(args)
end
......@@ -113,12 +114,4 @@ function class:get_true_action()
return result
end
function class:dump()
local dump = super.dump(self)
dump.data_renderer = self.data_renderer:dump()
dump.data_source = self.data_source:dump()
dump.reward_function = self.reward_function:dump()
return dump
end
return module.GraphicalAnnotatedSweetHome