Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Alexis Brenon
arcades
Commits
30404018
Commit
30404018
authored
Sep 22, 2017
by
Alexis Brenon
Browse files
Merge branch 'master' into faulty experiments
Conflicts: expe_template.lua
parents
18df868c
bbe3c83c
Changes
4
Hide whitespace changes
Inline
Side-by-side
arcades/agent/NeuralQLearner.lua
View file @
30404018
...
...
@@ -92,10 +92,6 @@ function class:__init(args, dump)
self
:
_init_inference_network
(
args
.
inference
,
dump
.
inference
)
--- Number of steps after which learning starts.
-- Add delay to populate the experience pool
-- @tfield number self.learn_start
self
.
learn_start
=
args
.
learn_start
or
(
0
.
1
*
self
.
experience_pool
.
pool
.
max_size
)
--- Learning frequency (epoch size).
-- @tfield number self.update_freq
self
.
update_freq
=
args
.
update_freq
or
12
...
...
@@ -105,6 +101,10 @@ function class:__init(args, dump)
--- Number of minibatch learning during a learning epoch.
-- @tfield number self.n_replay
self
.
n_replay
=
args
.
n_replay
or
1
--- Number of steps after which learning starts.
-- Add delay to populate the experience pool
-- @tfield number self.learn_start
self
.
learn_start
=
args
.
learn_start
or
math.max
(
self
.
minibatch_size
+
1
,
(
0
.
1
*
self
.
experience_pool
.
pool
.
max_size
))
--- Scale rewards delta.
-- @tfield boolean self.rescale_r
...
...
arcades/environment/smarthome/sweethome/GraphicalAnnotatedSweetHome.lua
View file @
30404018
...
...
@@ -42,7 +42,7 @@ function class:__init(args)
)
else
self
.
_logger
:
info
(
"Using real data"
)
error
(
"Not yet implemented"
)
self
.
data_source
=
data_sources
.
labelled
.
AnnotatedLabelledDataProvider
(
args
)
end
self
.
reward_function
=
environment
.
states
.
rewardfunction
.
LabelledReward
(
...
...
arcades/environment/states/datasource/labelled/AnnotatedLabelledDataProvider.lua
0 → 100644
View file @
30404018
local
torch
=
require
(
'torch'
)
local
paths
=
require
(
'paths'
)
local
tablex
=
require
(
'pl.tablex'
)
local
module
=
{}
local
environment
=
require
(
'arcades.environment'
)
assert
(
environment
.
states
.
datasource
.
DataSource
)
local
class
,
super
=
torch
.
class
(
'AnnotatedLabelledDataProvider'
,
'DataSource'
,
module
)
function
class
:
__init
(
args
,
dump
)
super
.
__init
(
self
,
args
,
dump
)
args
=
args
or
{};
dump
=
dump
or
{}
self
.
environment_model
=
args
.
environment_model
self
.
data_path
=
args
.
data_path
if
dump
.
data
then
self
.
data
=
dump
.
data
else
self
.
_logger
:
debug
(
"Loading data from '%s'"
,
self
.
data_path
)
self
.
data
=
self
:
_parse_data
(
self
.
data_path
)
end
--- ID (subject index, event index) of the last returned state
-- @within Attributes
-- @tfield number self.last_state[1] Subject ID
-- @tfield number self.last_state[2] Event ID
-- @table self.last_state
self
.
last_state
=
dump
.
last_state
or
{
1
,
0
}
self
.
max_tries
=
args
.
max_tries
or
1
self
.
remaining_tries
=
dump
.
remaining_tries
or
self
.
max_tries
end
--- Parse data from annotation files.
-- @tparam string data_path Path to the data (this can be a unique file or a folder)
-- @treturn table A table containing internal state representations
function
class
:
_parse_data
(
data_path
)
assert
(
data_path
,
"Path to data not given."
)
assert
(
paths
.
dirp
(
data_path
))
local
pool_data
=
{}
local
files
=
{}
for
file_name
in
paths
.
iterfiles
(
data_path
)
do
local
file
=
paths
.
concat
(
data_path
,
file_name
)
local
data
=
{}
for
line
in
io.lines
(
file
)
do
local
user_command
,
user_location
,
user_activity
,
command
,
command_location
=
string.match
(
line
,
"([^\t]+)\t([^\t]+)\t([^\t]+)\t%->\t([^\t]+)\t([^\t]+)"
)
local
state
=
{
user1inferredcommand
=
user_command
,
location
=
user_location
,
-- Just for information, not use in placeholders
user1inferredactivity
=
user_activity
}
for
_
,
loc
in
ipairs
(
self
.
environment_model
.
USER_LOCATIONS
)
do
state
[
"user1inferredlocation_"
..
loc
]
=
false
-- remove old locs
end
state
[
"user1inferredlocation_"
..
user_location
]
=
1
table.insert
(
data
,
{
state
=
state
,
expected_action
=
{
command
,
command_location
}
})
end
table.insert
(
pool_data
,
data
)
end
return
pool_data
end
function
class
:
update
(
action_index
)
if
((
not
self
.
terminal
)
and
(
type
(
action_index
)
==
"number"
and
action_index
>
0
))
then
local
expected_action_index
=
(
self
.
environment_model
.
actions_index
[
self
.
expected_action
[
1
]][
self
.
expected_action
[
2
]]
)
if
(
self
.
remaining_tries
<=
1
or
action_index
==
expected_action_index
)
then
self
.
terminal
=
true
end
self
.
remaining_tries
=
self
.
remaining_tries
-
1
end
return
self
end
function
class
:
reset
(
args
)
args
=
args
or
{}
if
args
.
state_id
then
self
.
state
=
tablex
.
copy
(
self
.
data
[
state_id
[
1
]][
state_id
[
2
]].
state
)
self
.
last_state
=
tablex
.
copy
(
state_id
)
else
self
.
state
=
self
:
_next_state
()
end
self
.
expected_action
=
tablex
.
copy
(
self
.
data
[
self
.
last_state
[
1
]][
self
.
last_state
[
2
]].
expected_action
)
self
.
terminal
=
false
self
.
remaining_tries
=
self
.
max_tries
return
self
end
function
class
:
dump
()
local
dump
=
super
.
dump
(
self
)
dump
.
data_path
=
self
.
data_path
dump
.
data
=
tablex
.
deepcopy
(
self
.
data
)
return
dump
end
function
class
:
__tostring__
()
local
dump
=
self
:
dump
()
local
result
=
torch
.
typename
(
self
)
..
"({data_path = '"
..
self
.
data_path
..
"'})"
return
result
end
--- Get next state according to the parsed data
-- @treturn table The new environment state
function
class
:
_next_state
()
if
self
.
last_state
[
2
]
==
#
(
self
.
data
[
self
.
last_state
[
1
]])
then
self
.
last_state
[
1
]
=
self
.
last_state
[
1
]
+
1
self
.
last_state
[
2
]
=
1
else
self
.
last_state
[
2
]
=
self
.
last_state
[
2
]
+
1
end
if
self
.
last_state
[
1
]
>
#
self
.
data
then
self
.
last_state
[
1
]
=
1
end
return
tablex
.
copy
(
self
.
data
[
self
.
last_state
[
1
]][
self
.
last_state
[
2
]].
state
)
end
return
module
.
AnnotatedLabelledDataProvider
cross-val.sh
View file @
30404018
...
...
@@ -28,9 +28,15 @@ mkdir "${test_dir}"
for
I
in
$(
seq
${
start_fold
}
${
end_fold
}
)
do
mv
"
$(
ls
-1d
"
${
train_dir
}
"
/
*
|
head
-n
${
I
}
|
tail
-n1
)
"
"
${
test_dir
}
"
test_data
=
"
$(
ls
-1d
"
${
train_dir
}
"
/
*
|
head
-n
${
I
}
|
tail
-n1
)
"
mv
"
${
test_data
}
"
"
${
test_dir
}
"
${
cmd
}
echo
"
$(
date
+%Y-%m-%dT%H:%M:%S
)
Running '
${
cmd
}
'"
>>
cross-val-log.txt
echo
"Test data is :
${
test_data
}
"
>>
cross-val-log.txt
${
cmd
}
||
exit
$?
echo
"
$(
date
+%Y-%m-%dT%H:%M:%S
)
End of command."
>>
cross-val-log.txt
mv
-vn
"
${
test_dir
}
"
/
*
"
${
train_dir
}
"
done
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment