Commit 3a858096 authored by Alexis Brenon's avatar Alexis Brenon
Browse files

Add a data provider for real annotated classes

parent ca15d7c8
......@@ -41,7 +41,7 @@ function class:__init(args)
)
else
self.logger:info("Using real data")
error("Not yet implemented")
self.data_source = data_sources.labelled.AnnotatedLabelledDataProvider(args)
end
self.reward_function = environment.states.rewardfunction.LabelledReward(
......
-- TODO:
-- Fix global name: DataSource --> DataProvider
-- Fix subclasses dependance tree
-- Fix subclasses naming scheme
return require('utils.package_loader')(...)
local torch = require('torch')
local paths = require('paths')
local tablex = require('pl.tablex')
local module = {}
local environment = require('environment')
assert(environment.states.datasource.DataSource)
local class, super = torch.class(
'AnnotatedLabelledDataProvider',
'DataSource',
module)
function class:__init(args)
super.__init(self, args)
self.environment_model = args.environment_model
self.data_path = args.data_path
self.logger:debug("Loading data from '%s'", self.data_path)
self.data = self:_parse_data(self.data_path)
--- ID (subject index, event index) of the last returned state
-- @within Attributes
-- @tfield number self.last_state[1] Subject ID
-- @tfield number self.last_state[2] Event ID
-- @table self.last_state
self.last_state = {1,0}
self.max_tries = args.max_tries or 1
self.remaining_tries = args.remaining_tries or self.max_tries
end
--- Parse data from annotation files.
-- @tparam string data_path Path to the data (this can be a unique file or a folder)
-- @treturn table A table containing internal state representations
function class:_parse_data(data_path)
assert(data_path, "Path to data not given.")
assert(paths.dirp(data_path))
local pool_data = {}
local files = {}
for file_name in paths.iterfiles(data_path) do
local file = paths.concat(data_path, file_name)
local data = {}
for line in io.lines(file) do
local user_command, user_location, user_activity,
command, command_location = string.match(
line,
"([^\t]+)\t([^\t]+)\t([^\t]+)\t%->\t([^\t]+)\t([^\t]+)"
)
local state = {
user1inferredcommand = user_command,
location = user_location, -- Just for information, not use in placeholders
user1inferredactivity = user_activity
}
for _, loc in ipairs(self.environment_model.USER_LOCATIONS) do
state["user1inferredlocation_" .. loc] = false -- remove old locs
end
state["user1inferredlocation_"..user_location] = 1
table.insert(data, {
state = state,
expected_action = {command, command_location}
})
end
table.insert(pool_data, data)
end
return pool_data
end
function class:update(action_index)
if ((not self.terminal) and
(type(action_index) == "number" and action_index > 0)) then
local expected_action_index = (
self.environment_model.actions_index[self.expected_action[1]][self.expected_action[2]]
)
if (
self.remaining_tries <= 1 or
action_index == expected_action_index
) then
self.terminal = true
end
self.remaining_tries = self.remaining_tries - 1
end
return self
end
function class:reset(args)
args = args or {}
if args.state_id then
self.state = tablex.copy(self.data[state_id[1]][state_id[2]].state)
self.last_state = tablex.copy(state_id)
else
self.state = self:_next_state()
end
self.expected_action = tablex.copy(
self.data[self.last_state[1]][self.last_state[2]].expected_action
)
self.terminal = false
self.remaining_tries = self.max_tries
return self
end
function class:dump()
local dump = super.dump(self)
dump.data_path = self.data_path
dump.data = tablex.deepcopy(self.data)
return dump
end
function class:__tostring__()
local dump = self:dump()
local result = torch.typename(self) ..
"({data_path = '" .. self.data_path .. "'})"
return result
end
--- Get next state according to the parsed data
-- @treturn table The new environment state
function class:_next_state()
if self.last_state[2] == #(self.data[self.last_state[1]]) then
self.last_state[1] = self.last_state[1] + 1
self.last_state[2] = 1
else
self.last_state[2] = self.last_state[2] + 1
end
if self.last_state[1] > #self.data then
self.last_state[1] = 1
end
return tablex.copy(self.data[self.last_state[1]][self.last_state[2]].state)
end
return module.AnnotatedLabelledDataProvider
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment