Commit 30404018 authored by Alexis Brenon's avatar Alexis Brenon
Browse files

Merge branch 'master' into faulty experiments

Conflicts:
	expe_template.lua
parents 18df868c bbe3c83c
......@@ -92,10 +92,6 @@ function class:__init(args, dump)
self:_init_inference_network(args.inference, dump.inference)
--- Number of steps after which learning starts.
-- Add delay to populate the experience pool
-- @tfield number self.learn_start
self.learn_start = args.learn_start or (0.1 * self.experience_pool.pool.max_size)
--- Learning frequency (epoch size).
-- @tfield number self.update_freq
self.update_freq = args.update_freq or 12
......@@ -105,6 +101,10 @@ function class:__init(args, dump)
--- Number of minibatch learning during a learning epoch.
-- @tfield number self.n_replay
self.n_replay = args.n_replay or 1
--- Number of steps after which learning starts.
-- Add delay to populate the experience pool
-- @tfield number self.learn_start
self.learn_start = args.learn_start or math.max(self.minibatch_size + 1, (0.1 * self.experience_pool.pool.max_size))
--- Scale rewards delta.
-- @tfield boolean self.rescale_r
......
......@@ -42,7 +42,7 @@ function class:__init(args)
)
else
self._logger:info("Using real data")
error("Not yet implemented")
self.data_source = data_sources.labelled.AnnotatedLabelledDataProvider(args)
end
self.reward_function = environment.states.rewardfunction.LabelledReward(
......
local torch = require('torch')
local paths = require('paths')
local tablex = require('pl.tablex')
local module = {}
local environment = require('arcades.environment')
assert(environment.states.datasource.DataSource)
local class, super = torch.class(
'AnnotatedLabelledDataProvider',
'DataSource',
module)
function class:__init(args, dump)
super.__init(self, args, dump)
args = args or {}; dump = dump or {}
self.environment_model = args.environment_model
self.data_path = args.data_path
if dump.data then
self.data = dump.data
else
self._logger:debug("Loading data from '%s'", self.data_path)
self.data = self:_parse_data(self.data_path)
end
--- ID (subject index, event index) of the last returned state
-- @within Attributes
-- @tfield number self.last_state[1] Subject ID
-- @tfield number self.last_state[2] Event ID
-- @table self.last_state
self.last_state = dump.last_state or {1,0}
self.max_tries = args.max_tries or 1
self.remaining_tries = dump.remaining_tries or self.max_tries
end
--- Parse data from annotation files.
-- @tparam string data_path Path to the data (this can be a unique file or a folder)
-- @treturn table A table containing internal state representations
function class:_parse_data(data_path)
assert(data_path, "Path to data not given.")
assert(paths.dirp(data_path))
local pool_data = {}
local files = {}
for file_name in paths.iterfiles(data_path) do
local file = paths.concat(data_path, file_name)
local data = {}
for line in io.lines(file) do
local user_command, user_location, user_activity,
command, command_location = string.match(
line,
"([^\t]+)\t([^\t]+)\t([^\t]+)\t%->\t([^\t]+)\t([^\t]+)"
)
local state = {
user1inferredcommand = user_command,
location = user_location, -- Just for information, not use in placeholders
user1inferredactivity = user_activity
}
for _, loc in ipairs(self.environment_model.USER_LOCATIONS) do
state["user1inferredlocation_" .. loc] = false -- remove old locs
end
state["user1inferredlocation_"..user_location] = 1
table.insert(data, {
state = state,
expected_action = {command, command_location}
})
end
table.insert(pool_data, data)
end
return pool_data
end
function class:update(action_index)
if ((not self.terminal) and
(type(action_index) == "number" and action_index > 0)) then
local expected_action_index = (
self.environment_model.actions_index[self.expected_action[1]][self.expected_action[2]]
)
if (
self.remaining_tries <= 1 or
action_index == expected_action_index
) then
self.terminal = true
end
self.remaining_tries = self.remaining_tries - 1
end
return self
end
function class:reset(args)
args = args or {}
if args.state_id then
self.state = tablex.copy(self.data[state_id[1]][state_id[2]].state)
self.last_state = tablex.copy(state_id)
else
self.state = self:_next_state()
end
self.expected_action = tablex.copy(
self.data[self.last_state[1]][self.last_state[2]].expected_action
)
self.terminal = false
self.remaining_tries = self.max_tries
return self
end
function class:dump()
local dump = super.dump(self)
dump.data_path = self.data_path
dump.data = tablex.deepcopy(self.data)
return dump
end
function class:__tostring__()
local dump = self:dump()
local result = torch.typename(self) ..
"({data_path = '" .. self.data_path .. "'})"
return result
end
--- Get next state according to the parsed data
-- @treturn table The new environment state
function class:_next_state()
if self.last_state[2] == #(self.data[self.last_state[1]]) then
self.last_state[1] = self.last_state[1] + 1
self.last_state[2] = 1
else
self.last_state[2] = self.last_state[2] + 1
end
if self.last_state[1] > #self.data then
self.last_state[1] = 1
end
return tablex.copy(self.data[self.last_state[1]][self.last_state[2]].state)
end
return module.AnnotatedLabelledDataProvider
......@@ -28,9 +28,15 @@ mkdir "${test_dir}"
for I in $(seq ${start_fold} ${end_fold})
do
mv "$(ls -1d "${train_dir}"/* | head -n ${I} | tail -n1)" "${test_dir}"
test_data="$(ls -1d "${train_dir}"/* | head -n ${I} | tail -n1)"
mv "${test_data}" "${test_dir}"
${cmd}
echo "$(date +%Y-%m-%dT%H:%M:%S) Running '${cmd}'" >> cross-val-log.txt
echo "Test data is : ${test_data}" >> cross-val-log.txt
${cmd} || exit $?
echo "$(date +%Y-%m-%dT%H:%M:%S) End of command." >> cross-val-log.txt
mv -vn "${test_dir}"/* "${train_dir}"
done
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment