Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Alexis Brenon
arcades
Commits
2ed6b3da
Commit
2ed6b3da
authored
Jul 10, 2017
by
Alexis Brenon
Browse files
🔀
Merge branch 'master' into xp/A1-real-annot-nopretrain-1
parents
07f9b52d
ba825191
Changes
7
Hide whitespace changes
Inline
Side-by-side
arcades/agent/NeuralQLearner.lua
View file @
2ed6b3da
...
...
@@ -63,7 +63,7 @@ function class:__init(args, dump)
args
=
args
or
{};
dump
=
dump
or
{}
--- Does the agent must use GPU or not?
self
.
use_gpu
=
package.loaded
[
"cutorch"
]
local
use_gpu
=
package.loaded
[
"cutorch"
]
--- Function to convert tensors if necessary.
--
-- This function must be called to convert tensors/network to the appropriate
...
...
@@ -71,7 +71,7 @@ function class:__init(args, dump)
-- inconsistent types
-- @tfield func self._convert_tensor
self
.
_convert_tensor
=
nil
if
self
.
use_gpu
then
if
use_gpu
then
self
.
_convert_tensor
=
function
(
t
)
return
t
:
cuda
()
end
else
local
default_type
=
torch
.
getdefaulttensortype
()
...
...
arcades/agent/_ExperiencePool.lua
View file @
2ed6b3da
...
...
@@ -76,7 +76,7 @@ local class, super = torch.class('_ExperiencePool', 'ArcadesComponent', module)
function
class
:
__init
(
args
,
dump
)
super
.
__init
(
self
,
args
,
dump
)
args
=
args
or
{};
dump
=
dump
or
{}
self
.
use_gpu
=
package.loaded
[
"cutorch"
]
local
use_gpu
=
package.loaded
[
"cutorch"
]
--- Function to convert tensors if necessary.
--
-- This function must be called to convert tensors/network to the appropriate
...
...
@@ -84,7 +84,7 @@ function class:__init(args, dump)
-- inconsistent types
-- @tfield func self._convert_tensor
self
.
_convert_tensor
=
nil
if
self
.
use_gpu
then
if
use_gpu
then
self
.
_convert_tensor
=
function
(
t
)
return
t
:
cuda
()
end
else
local
default_type
=
torch
.
getdefaulttensortype
()
...
...
@@ -119,9 +119,9 @@ function class:__init(args, dump)
end
self
.
pushed_pools
=
dump
.
pushed_pools
or
{}
self
.
history_length
=
dump
.
history_length
or
1
self
.
history_type
=
dump
.
history_type
or
"linear"
self
.
history_spacing
=
dump
.
history_spacing
or
1
self
.
history_length
=
args
.
history_length
or
1
self
.
history_type
=
args
.
history_type
or
"linear"
self
.
history_spacing
=
args
.
history_spacing
or
1
self
.
history_offsets
=
self
:
_compute_history_offsets
()
self
.
history_stacked_state_size
=
self
.
nil_state
:
size
():
totable
()
...
...
arcades/environment/states/datasource/simulated/SweetHomeDataInferredSimulated.lua
View file @
2ed6b3da
...
...
@@ -72,7 +72,11 @@ function class:get_expected_action()
end
function
class
:
set_state
(
command
,
location
,
activity
)
-- TODO: remove the location field, rely on _location instead
self
.
state
=
{
_command
=
command
,
-- For debug, not use in placeholders
_location
=
location
,
-- For debug, not use in placeholders
_activity
=
activity
,
-- For debug, not use in placeholders
user1inferredcommand
=
command
,
location
=
location
,
-- Just for information, not use in placeholders
user1inferredactivity
=
activity
...
...
arcades/experiment/BaseExperiment.lua
View file @
2ed6b3da
...
...
@@ -155,6 +155,7 @@ function class:report()
_
=
self
.
output
.
report
.
text
and
self
:
_text_report
()
_
=
self
.
output
.
report
.
torch
and
self
:
_torch_report
()
_
=
self
.
output
.
report
.
graphical
and
self
:
_graphical_report
()
_
=
self
.
output
.
report
.
input
>
0
and
self
:
_inputs_report
()
return
self
end
...
...
@@ -509,6 +510,30 @@ function class:_plot_f1_score()
return
self
end
function
class
:
_inputs_report
()
self
.
_logger
:
debug
(
"Saving input frames."
)
if
#
self
.
loop
.
train
.
inputs
>
0
then
image
.
savePNG
(
paths
.
concat
(
self
.
save_path
,
"train_inputs.png"
),
image
.
toDisplayTensor
({
input
=
self
.
loop
.
train
.
inputs
,
padding
=
6
})
)
end
if
#
self
.
loop
.
test
.
inputs
>
0
then
image
.
savePNG
(
paths
.
concat
(
self
.
save_path
,
"test_inputs.png"
),
image
.
toDisplayTensor
({
input
=
self
.
loop
.
test
.
inputs
,
padding
=
6
})
)
end
return
self
end
--- Save a graphical representation of agent network filters
-- @return self
function
class
:
_plot_network_filters
()
...
...
@@ -587,6 +612,7 @@ function class:_train(steps)
confusion_matrix
=
interactions
.
confusion_matrix
,
interactions
=
interactions
.
interactions
,
classification_metrics
=
classification_metrics
,
inputs
=
interactions
.
inputs
,
}
end
...
...
@@ -624,6 +650,7 @@ function class:_test(steps)
confusion_matrix
=
interactions
.
confusion_matrix
,
interactions
=
interactions
.
interactions
,
classification_metrics
=
classification_metrics
,
inputs
=
interactions
.
inputs
,
}
end
...
...
@@ -646,6 +673,7 @@ function class:_interact(environment, steps)
}
local
env_num_actions
=
#
(
environment
:
actions
())
local
confusion_matrix
=
torch
.
zeros
(
env_num_actions
,
env_num_actions
)
local
inputs
=
{}
for
step
=
1
,
steps
do
num_it
=
step
...
...
@@ -672,6 +700,10 @@ function class:_interact(environment, steps)
confusion_matrix
[
true_action
][
action
]
=
confusion_matrix
[
true_action
][
action
]
+
1
end
if
step
>
steps
-
self
.
output
.
report
.
input
then
table.insert
(
inputs
,
state
.
observation
)
end
if
state
.
terminal
then
assert
(
action
==
0
)
environment
:
reset
()
...
...
@@ -682,6 +714,7 @@ function class:_interact(environment, steps)
num_it
=
num_it
,
interactions
=
interactions
,
confusion_matrix
=
confusion_matrix
,
inputs
=
inputs
}
end
...
...
arcades/utils/argparse.lua
View file @
2ed6b3da
...
...
@@ -42,6 +42,7 @@ function class:__init()
{
"field"
,
"text"
,
true
,
"Produce text/log report"
,
"boolean"
},
{
"field"
,
"graphical"
,
true
,
"Produce graphical report"
,
"boolean"
},
{
"field"
,
"torch"
,
true
,
"Produce torch object report"
,
"boolean"
},
{
"field"
,
"input"
,
3
,
"Output some encountered input frames"
,
"number"
},
},
},
{
"field"
,
"metrics"
,
""
,
"Which kind of metrics to report"
,
"table"
,
...
...
arcades/utils/setup.lua
View file @
2ed6b3da
...
...
@@ -122,18 +122,28 @@ local function _environment_setup(args)
envs
.
test
=
envs
.
train
end
if
envs
.
train
and
envs
.
test
then
local
training_feature
=
envs
.
train
:
get_observable_state
().
observation
local
testing_feature
=
envs
.
test
:
get_observable_state
().
observation
assert
(
training_feature
:
isSameSizeAs
(
testing_feature
),
"Training and testing environments must have the same features sizes... Aborting"
)
assert
(
#
envs
.
train
:
actions
()
==
#
envs
.
test
:
actions
(),
"Training and testing environments must accept the same range of actions... Aborting"
)
end
local
training_feature
=
envs
.
train
:
get_observable_state
().
observation
local
testing_feature
=
envs
.
test
:
get_observable_state
().
observation
assert
(
training_feature
:
isSameSizeAs
(
testing_feature
),
"Training and testing environments must have the same features sizes... Aborting"
)
assert
(
#
envs
.
train
:
actions
()
==
#
envs
.
test
:
actions
(),
"Training and testing environments must accept the same range of actions... Aborting"
)
require
(
'image'
).
savePNG
(
paths
.
concat
(
args
.
output
.
path
,
'training_features.png'
),
training_feature
)
require
(
'image'
).
savePNG
(
paths
.
concat
(
args
.
output
.
path
,
'testing_features.png'
),
testing_feature
)
return
envs
end
...
...
assets/domus.svg
View file @
2ed6b3da
...
...
@@ -24,16 +24,16 @@
inkscape:pageopacity=
"0"
inkscape:pageshadow=
"2"
inkscape:window-width=
"1920"
inkscape:window-height=
"1
05
6"
inkscape:window-height=
"1
17
6"
id=
"namedview11423"
showgrid=
"true"
inkscape:zoom=
"0.8344"
inkscape:cx=
"
-15.05
"
inkscape:cx=
"
132.4
"
inkscape:cy=
"0.2405"
inkscape:window-x=
"0"
inkscape:window-x=
"
192
0"
inkscape:window-y=
"24"
inkscape:window-maximized=
"1"
inkscape:current-layer=
"
g18402
"
inkscape:current-layer=
"
domus
"
inkscape:snap-bbox=
"false"
inkscape:bbox-nodes=
"true"
inkscape:snap-nodes=
"true"
...
...
@@ -786,8 +786,7 @@
width=
"100%"
height=
"100%"
xlink:href=
"#command/none"
/></svg:g></svg:g></svg:g><svg:g
id=
"sensors"
style=
"display:none"
><svg:g
id=
"sensors"
><svg:g
id=
"lamp"
><svg:g
transform=
"matrix(2,0,0,2,156,396)"
id=
"g17959"
><svg:title
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment