Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Alexis Brenon
arcades
Commits
7075e7ef
Commit
7075e7ef
authored
Jul 05, 2017
by
Alexis Brenon
Browse files
✨
Add some debug of the input frames
parent
8da75836
Changes
3
Hide whitespace changes
Inline
Side-by-side
arcades/experiment/BaseExperiment.lua
View file @
7075e7ef
...
...
@@ -155,6 +155,7 @@ function class:report()
_
=
self
.
output
.
report
.
text
and
self
:
_text_report
()
_
=
self
.
output
.
report
.
torch
and
self
:
_torch_report
()
_
=
self
.
output
.
report
.
graphical
and
self
:
_graphical_report
()
_
=
self
.
output
.
report
.
input
>
0
and
self
:
_inputs_report
()
return
self
end
...
...
@@ -509,6 +510,30 @@ function class:_plot_f1_score()
return
self
end
function
class
:
_inputs_report
()
self
.
logger
:
debug
(
"Saving input frames."
)
if
#
self
.
loop
.
train
.
inputs
>
0
then
image
.
savePNG
(
paths
.
concat
(
self
.
save_path
,
"train_inputs.png"
),
image
.
toDisplayTensor
({
input
=
self
.
loop
.
train
.
inputs
,
padding
=
6
})
)
end
if
#
self
.
loop
.
test
.
inputs
>
0
then
image
.
savePNG
(
paths
.
concat
(
self
.
save_path
,
"test_inputs.png"
),
image
.
toDisplayTensor
({
input
=
self
.
loop
.
test
.
inputs
,
padding
=
6
})
)
end
return
self
end
--- Save a graphical representation of agent network filters
-- @return self
function
class
:
_plot_network_filters
()
...
...
@@ -587,6 +612,7 @@ function class:_train(steps)
confusion_matrix
=
interactions
.
confusion_matrix
,
interactions
=
interactions
.
interactions
,
classification_metrics
=
classification_metrics
,
inputs
=
interactions
.
inputs
,
}
end
...
...
@@ -624,6 +650,7 @@ function class:_test(steps)
confusion_matrix
=
interactions
.
confusion_matrix
,
interactions
=
interactions
.
interactions
,
classification_metrics
=
classification_metrics
,
inputs
=
interactions
.
inputs
,
}
end
...
...
@@ -646,6 +673,7 @@ function class:_interact(environment, steps)
}
local
env_num_actions
=
#
(
environment
:
actions
())
local
confusion_matrix
=
torch
.
zeros
(
env_num_actions
,
env_num_actions
)
local
inputs
=
{}
for
step
=
1
,
steps
do
num_it
=
step
...
...
@@ -672,6 +700,10 @@ function class:_interact(environment, steps)
confusion_matrix
[
true_action
][
action
]
=
confusion_matrix
[
true_action
][
action
]
+
1
end
if
step
>
steps
-
self
.
output
.
report
.
input
then
table.insert
(
inputs
,
state
.
observation
)
end
if
state
.
terminal
then
assert
(
action
==
0
)
environment
:
reset
()
...
...
@@ -682,6 +714,7 @@ function class:_interact(environment, steps)
num_it
=
num_it
,
interactions
=
interactions
,
confusion_matrix
=
confusion_matrix
,
inputs
=
inputs
}
end
...
...
arcades/utils/argparse.lua
View file @
7075e7ef
...
...
@@ -42,6 +42,7 @@ function class:__init()
{
"field"
,
"text"
,
true
,
"Produce text/log report"
,
"boolean"
},
{
"field"
,
"graphical"
,
true
,
"Produce graphical report"
,
"boolean"
},
{
"field"
,
"torch"
,
true
,
"Produce torch object report"
,
"boolean"
},
{
"field"
,
"input"
,
3
,
"Output some encountered input frames"
,
"number"
},
},
},
{
"field"
,
"metrics"
,
""
,
"Which kind of metrics to report"
,
"table"
,
...
...
arcades/utils/setup.lua
View file @
7075e7ef
...
...
@@ -122,18 +122,28 @@ local function _environment_setup(args)
envs
.
test
=
envs
.
train
end
if
envs
.
train
and
envs
.
test
then
local
training_feature
=
envs
.
train
:
get_observable_state
().
observation
local
testing_feature
=
envs
.
test
:
get_observable_state
().
observation
assert
(
training_feature
:
isSameSizeAs
(
testing_feature
),
"Training and testing environments must have the same features sizes... Aborting"
)
assert
(
#
envs
.
train
:
actions
()
==
#
envs
.
test
:
actions
(),
"Training and testing environments must accept the same range of actions... Aborting"
)
end
local
training_feature
=
envs
.
train
:
get_observable_state
().
observation
local
testing_feature
=
envs
.
test
:
get_observable_state
().
observation
assert
(
training_feature
:
isSameSizeAs
(
testing_feature
),
"Training and testing environments must have the same features sizes... Aborting"
)
assert
(
#
envs
.
train
:
actions
()
==
#
envs
.
test
:
actions
(),
"Training and testing environments must accept the same range of actions... Aborting"
)
require
(
'image'
).
savePNG
(
paths
.
concat
(
args
.
output
.
path
,
'training_features.png'
),
training_feature
)
require
(
'image'
).
savePNG
(
paths
.
concat
(
args
.
output
.
path
,
'testing_features.png'
),
testing_feature
)
return
envs
end
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment