Commit 7075e7ef authored by Alexis Brenon's avatar Alexis Brenon
Browse files

Add some debug of the input frames

parent 8da75836
......@@ -155,6 +155,7 @@ function class:report()
_ = self.output.report.text and self:_text_report()
_ = self.output.report.torch and self:_torch_report()
_ = self.output.report.graphical and self:_graphical_report()
_ = self.output.report.input > 0 and self:_inputs_report()
return self
end
......@@ -509,6 +510,30 @@ function class:_plot_f1_score()
return self
end
function class:_inputs_report()
self.logger:debug("Saving input frames.")
if #self.loop.train.inputs > 0 then
image.savePNG(
paths.concat(self.save_path, "train_inputs.png"),
image.toDisplayTensor({
input = self.loop.train.inputs,
padding = 6
})
)
end
if #self.loop.test.inputs > 0 then
image.savePNG(
paths.concat(self.save_path, "test_inputs.png"),
image.toDisplayTensor({
input = self.loop.test.inputs,
padding = 6
})
)
end
return self
end
--- Save a graphical representation of agent network filters
-- @return self
function class:_plot_network_filters()
......@@ -587,6 +612,7 @@ function class:_train(steps)
confusion_matrix = interactions.confusion_matrix,
interactions = interactions.interactions,
classification_metrics = classification_metrics,
inputs = interactions.inputs,
}
end
......@@ -624,6 +650,7 @@ function class:_test(steps)
confusion_matrix = interactions.confusion_matrix,
interactions = interactions.interactions,
classification_metrics = classification_metrics,
inputs = interactions.inputs,
}
end
......@@ -646,6 +673,7 @@ function class:_interact(environment, steps)
}
local env_num_actions = #(environment:actions())
local confusion_matrix = torch.zeros(env_num_actions, env_num_actions)
local inputs = {}
for step = 1,steps do
num_it = step
......@@ -672,6 +700,10 @@ function class:_interact(environment, steps)
confusion_matrix[true_action][action] = confusion_matrix[true_action][action] + 1
end
if step > steps - self.output.report.input then
table.insert(inputs, state.observation)
end
if state.terminal then
assert(action == 0)
environment:reset()
......@@ -682,6 +714,7 @@ function class:_interact(environment, steps)
num_it = num_it,
interactions = interactions,
confusion_matrix = confusion_matrix,
inputs = inputs
}
end
......
......@@ -42,6 +42,7 @@ function class:__init()
{"field", "text", true, "Produce text/log report", "boolean"},
{"field", "graphical", true, "Produce graphical report", "boolean"},
{"field", "torch", true, "Produce torch object report", "boolean"},
{"field", "input", 3, "Output some encountered input frames", "number"},
},
},
{"field", "metrics", "", "Which kind of metrics to report", "table",
......
......@@ -122,18 +122,28 @@ local function _environment_setup(args)
envs.test = envs.train
end
if envs.train and envs.test then
local training_feature = envs.train:get_observable_state().observation
local testing_feature = envs.test:get_observable_state().observation
assert(
training_feature:isSameSizeAs(testing_feature),
"Training and testing environments must have the same features sizes... Aborting"
)
assert(
#envs.train:actions() == #envs.test:actions(),
"Training and testing environments must accept the same range of actions... Aborting"
)
end
local training_feature = envs.train:get_observable_state().observation
local testing_feature = envs.test:get_observable_state().observation
assert(
training_feature:isSameSizeAs(testing_feature),
"Training and testing environments must have the same features sizes... Aborting"
)
assert(
#envs.train:actions() == #envs.test:actions(),
"Training and testing environments must accept the same range of actions... Aborting"
)
require('image').savePNG(
paths.concat(
args.output.path,
'training_features.png'
),
training_feature)
require('image').savePNG(
paths.concat(
args.output.path,
'testing_features.png'
),
testing_feature)
return envs
end
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment