Commit 2ed6b3da authored by Alexis Brenon's avatar Alexis Brenon
Browse files

🔀 Merge branch 'master' into xp/A1-real-annot-nopretrain-1

parents 07f9b52d ba825191
......@@ -63,7 +63,7 @@ function class:__init(args, dump)
args = args or {}; dump = dump or {}
--- Does the agent must use GPU or not?
self.use_gpu = package.loaded["cutorch"]
local use_gpu = package.loaded["cutorch"]
--- Function to convert tensors if necessary.
--
-- This function must be called to convert tensors/network to the appropriate
......@@ -71,7 +71,7 @@ function class:__init(args, dump)
-- inconsistent types
-- @tfield func self._convert_tensor
self._convert_tensor = nil
if self.use_gpu then
if use_gpu then
self._convert_tensor = function(t) return t:cuda() end
else
local default_type = torch.getdefaulttensortype()
......
......@@ -76,7 +76,7 @@ local class, super = torch.class('_ExperiencePool', 'ArcadesComponent', module)
function class:__init(args, dump)
super.__init(self, args, dump)
args = args or {}; dump = dump or {}
self.use_gpu = package.loaded["cutorch"]
local use_gpu = package.loaded["cutorch"]
--- Function to convert tensors if necessary.
--
-- This function must be called to convert tensors/network to the appropriate
......@@ -84,7 +84,7 @@ function class:__init(args, dump)
-- inconsistent types
-- @tfield func self._convert_tensor
self._convert_tensor = nil
if self.use_gpu then
if use_gpu then
self._convert_tensor = function(t) return t:cuda() end
else
local default_type = torch.getdefaulttensortype()
......@@ -119,9 +119,9 @@ function class:__init(args, dump)
end
self.pushed_pools = dump.pushed_pools or {}
self.history_length = dump.history_length or 1
self.history_type = dump.history_type or "linear"
self.history_spacing = dump.history_spacing or 1
self.history_length = args.history_length or 1
self.history_type = args.history_type or "linear"
self.history_spacing = args.history_spacing or 1
self.history_offsets = self:_compute_history_offsets()
self.history_stacked_state_size = self.nil_state:size():totable()
......
......@@ -72,7 +72,11 @@ function class:get_expected_action()
end
function class:set_state(command, location, activity)
-- TODO: remove the location field, rely on _location instead
self.state = {
_command = command, -- For debug, not used in placeholders
_location = location, -- For debug, not used in placeholders
_activity = activity, -- For debug, not used in placeholders
user1inferredcommand = command,
location = location, -- Just for information, not used in placeholders
user1inferredactivity = activity
......
......@@ -155,6 +155,7 @@ function class:report()
_ = self.output.report.text and self:_text_report()
_ = self.output.report.torch and self:_torch_report()
_ = self.output.report.graphical and self:_graphical_report()
_ = self.output.report.input > 0 and self:_inputs_report()
return self
end
......@@ -509,6 +510,30 @@ function class:_plot_f1_score()
return self
end
function class:_inputs_report()
self._logger:debug("Saving input frames.")
if #self.loop.train.inputs > 0 then
image.savePNG(
paths.concat(self.save_path, "train_inputs.png"),
image.toDisplayTensor({
input = self.loop.train.inputs,
padding = 6
})
)
end
if #self.loop.test.inputs > 0 then
image.savePNG(
paths.concat(self.save_path, "test_inputs.png"),
image.toDisplayTensor({
input = self.loop.test.inputs,
padding = 6
})
)
end
return self
end
--- Save a graphical representation of agent network filters
-- @return self
function class:_plot_network_filters()
......@@ -587,6 +612,7 @@ function class:_train(steps)
confusion_matrix = interactions.confusion_matrix,
interactions = interactions.interactions,
classification_metrics = classification_metrics,
inputs = interactions.inputs,
}
end
......@@ -624,6 +650,7 @@ function class:_test(steps)
confusion_matrix = interactions.confusion_matrix,
interactions = interactions.interactions,
classification_metrics = classification_metrics,
inputs = interactions.inputs,
}
end
......@@ -646,6 +673,7 @@ function class:_interact(environment, steps)
}
local env_num_actions = #(environment:actions())
local confusion_matrix = torch.zeros(env_num_actions, env_num_actions)
local inputs = {}
for step = 1,steps do
num_it = step
......@@ -672,6 +700,10 @@ function class:_interact(environment, steps)
confusion_matrix[true_action][action] = confusion_matrix[true_action][action] + 1
end
if step > steps - self.output.report.input then
table.insert(inputs, state.observation)
end
if state.terminal then
assert(action == 0)
environment:reset()
......@@ -682,6 +714,7 @@ function class:_interact(environment, steps)
num_it = num_it,
interactions = interactions,
confusion_matrix = confusion_matrix,
inputs = inputs
}
end
......
......@@ -42,6 +42,7 @@ function class:__init()
{"field", "text", true, "Produce text/log report", "boolean"},
{"field", "graphical", true, "Produce graphical report", "boolean"},
{"field", "torch", true, "Produce torch object report", "boolean"},
{"field", "input", 3, "Output some encountered input frames", "number"},
},
},
{"field", "metrics", "", "Which kind of metrics to report", "table",
......
......@@ -122,18 +122,28 @@ local function _environment_setup(args)
envs.test = envs.train
end
if envs.train and envs.test then
local training_feature = envs.train:get_observable_state().observation
local testing_feature = envs.test:get_observable_state().observation
assert(
training_feature:isSameSizeAs(testing_feature),
"Training and testing environments must have the same features sizes... Aborting"
)
assert(
#envs.train:actions() == #envs.test:actions(),
"Training and testing environments must accept the same range of actions... Aborting"
)
end
local training_feature = envs.train:get_observable_state().observation
local testing_feature = envs.test:get_observable_state().observation
assert(
training_feature:isSameSizeAs(testing_feature),
"Training and testing environments must have the same features sizes... Aborting"
)
assert(
#envs.train:actions() == #envs.test:actions(),
"Training and testing environments must accept the same range of actions... Aborting"
)
require('image').savePNG(
paths.concat(
args.output.path,
'training_features.png'
),
training_feature)
require('image').savePNG(
paths.concat(
args.output.path,
'testing_features.png'
),
testing_feature)
return envs
end
......
......@@ -24,16 +24,16 @@
inkscape:pageopacity="0"
inkscape:pageshadow="2"
inkscape:window-width="1920"
inkscape:window-height="1056"
inkscape:window-height="1176"
id="namedview11423"
showgrid="true"
inkscape:zoom="0.8344"
inkscape:cx="-15.05"
inkscape:cx="132.4"
inkscape:cy="0.2405"
inkscape:window-x="0"
inkscape:window-x="1920"
inkscape:window-y="24"
inkscape:window-maximized="1"
inkscape:current-layer="g18402"
inkscape:current-layer="domus"
inkscape:snap-bbox="false"
inkscape:bbox-nodes="true"
inkscape:snap-nodes="true"
......@@ -786,8 +786,7 @@
width="100%"
height="100%"
xlink:href="#command/none" /></svg:g></svg:g></svg:g><svg:g
id="sensors"
style="display:none"><svg:g
id="sensors"><svg:g
id="lamp"><svg:g
transform="matrix(2,0,0,2,156,396)"
id="g17959"><svg:title
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment