-- AnnotatedLabelledDataProvider.lua
local torch = require('torch')
local paths = require('paths')

local tablex = require('pl.tablex')

local module = {}
local environment = require('arcades.environment')

-- Fail at load time if arcades.environment does not expose the DataSource
-- class that is named as the parent below.
assert(environment.states.datasource.DataSource)
-- Declare the 'AnnotatedLabelledDataProvider' torch class with 'DataSource'
-- as parent. The class is registered in the local `module` table, which is
-- where the final `return` of this file picks it up. `class` receives the new
-- class' method table and `super` the parent's one.
local class, super = torch.class(
  'AnnotatedLabelledDataProvider',
  'DataSource',
  module)

--- Build the provider and eagerly load the annotated data.
-- `args` is forwarded unchanged to the parent constructor before any field is
-- read from it.
-- @tparam table args Constructor arguments
-- @param args.environment_model Environment model; stored before parsing
-- because `_parse_data` reads it
-- @tparam string args.data_path Location of the annotation data
-- @tparam[opt=1] number args.max_tries Number of tries restored by `reset`
-- @tparam[opt] number args.remaining_tries Initial remaining tries
-- (defaults to the resolved `max_tries`)
function class:__init(args)
  super.__init(self, args)

  local data_path = args.data_path
  self.environment_model = args.environment_model
  self.data_path = data_path
  self.logger:debug("Loading data from '%s'", data_path)
  self.data = self:_parse_data(data_path)

  --- ID (subject index, event index) of the last returned state.
  -- Starts at event 0 of subject 1, so that the first call to `_next_state`
  -- moves on to the first event.
  -- @within Attributes
  -- @tfield number self.last_state[1] Subject ID
  -- @tfield number self.last_state[2] Event ID
  -- @table self.last_state
  self.last_state = {1, 0}

  local max_tries = args.max_tries or 1
  self.max_tries = max_tries
  self.remaining_tries = args.remaining_tries or max_tries
end


--- Parse data from annotation files.
-- Every file found directly inside `data_path` is read as one subject. Each
-- non-blank line must hold five tab-separated fields laid out as
-- `user_command <TAB> user_location <TAB> user_activity <TAB> -> <TAB>
-- command <TAB> command_location`.
-- Blank lines are skipped; any other line that does not match this layout
-- raises an error naming the offending file and line number.
-- @tparam string data_path Path to the directory holding the annotation files
-- @treturn table One entry per file, each being the ordered list of
-- `{state = ..., expected_action = {command, command_location}}` records
function class:_parse_data(data_path)
  assert(data_path, "Path to data not given.")
  assert(
    paths.dirp(data_path),
    string.format("Data path '%s' is not a directory.", tostring(data_path))
  )

  local pool_data = {}

  for file_name in paths.iterfiles(data_path) do
    local file = paths.concat(data_path, file_name)
    local data = {}
    local line_number = 0
    for line in io.lines(file) do
      line_number = line_number + 1
      -- Lines made only of whitespace (e.g. a trailing empty line) carry no
      -- annotation and are ignored.
      if string.find(line, "%S") then
        local user_command, user_location, user_activity,
          command, command_location = string.match(
            line,
            "([^\t]+)\t([^\t]+)\t([^\t]+)\t%->\t([^\t]+)\t([^\t]+)"
          )
        if not user_command then
          -- Without this check the failure would surface later as an opaque
          -- "attempt to concatenate a nil value" error.
          error(string.format(
            "Malformed annotation at %s:%d: '%s'", file, line_number, line
          ))
        end
        local state = {
          user1inferredcommand = user_command,
          location = user_location, -- Informational only, not used in placeholders
          user1inferredactivity = user_activity
        }
        -- Flag every known location as false first, then mark the observed
        -- one (with the value 1).
        for _, loc in ipairs(self.environment_model.USER_LOCATIONS) do
          state["user1inferredlocation_" .. loc] = false
        end
        state["user1inferredlocation_" .. user_location] = 1
        table.insert(data, {
          state = state,
          expected_action = {command, command_location}
        })
      end
    end
    table.insert(pool_data, data)
  end

  return pool_data
end

--- Consume one try with the action chosen by the agent.
-- Nothing happens when the episode is already terminal or when
-- `action_index` is not a strictly positive number. Otherwise the episode
-- becomes terminal if this was the last remaining try or if the action is the
-- expected one, and the number of remaining tries is decremented.
-- @param action_index Index of the performed action
-- @return self
function class:update(action_index)
  if self.terminal then
    return self
  end
  if type(action_index) ~= "number" or not (action_index > 0) then
    return self
  end

  local expected = self.expected_action
  local expected_action_index =
    self.environment_model.actions_index[expected[1]][expected[2]]

  local last_try = self.remaining_tries <= 1
  if last_try or action_index == expected_action_index then
    self.terminal = true
  end
  self.remaining_tries = self.remaining_tries - 1

  return self
end

--- Start a new episode.
-- When `args.state_id` is given, the provider jumps to that
-- (subject index, event index) pair; otherwise it moves on to the next state
-- in the parsed data. In both cases the expected action is refreshed, the
-- terminal flag is cleared and the number of tries is restored to `max_tries`.
-- @tparam[opt] table args Optional arguments
-- @tparam[opt] table args.state_id `{subject_index, event_index}` to jump to
-- @return self
function class:reset(args)
  args = args or {}
  -- The requested id lives in `args`; reading a bare `state_id` would hit an
  -- undefined global and raise as soon as it is indexed.
  local state_id = args.state_id
  local entry
  if state_id then
    local subject = self.data[state_id[1]]
    entry = subject and subject[state_id[2]]
    assert(entry, string.format(
      "Unknown state_id (%s, %s).",
      tostring(state_id[1]), tostring(state_id[2])
    ))
    self.state = tablex.copy(entry.state)
    self.last_state = tablex.copy(state_id)
  else
    self.state = self:_next_state()
    entry = self.data[self.last_state[1]][self.last_state[2]]
  end
  self.expected_action = tablex.copy(entry.expected_action)
  self.terminal = false
  self.remaining_tries = self.max_tries
  return self
end

--- Export the provider content.
-- Extends the table produced by the parent `dump` with the data path and a
-- deep copy of the parsed data.
-- @treturn table The dump table
function class:dump()
  local snapshot = super.dump(self)
  snapshot.data = tablex.deepcopy(self.data)
  snapshot.data_path = self.data_path
  return snapshot
end

--- Short textual description of the provider: its torch type name followed
-- by the data path it was built from.
-- The string is assembled from `data_path` alone, so no dump (and therefore
-- no deep copy of the parsed data) is made here.
-- @treturn string Human readable description
function class:__tostring__()
  return torch.typename(self) ..
    "({data_path = '" .. self.data_path .. "'})"
end


--- Get next state according to the parsed data.
-- Moves to the next event of the current subject or, once the last event of
-- that subject has been returned, to the first event of the following
-- subject. After the last subject the iteration wraps around to the first
-- one. Subjects without any event are skipped. `self.last_state` is updated
-- in place.
-- Raises an error when no subject was loaded or when every subject is empty.
-- @treturn table A copy of the new environment state
function class:_next_state()
  local data = self.data
  local subject_count = #data
  assert(subject_count > 0, "No annotated data available.")

  local last_state = self.last_state
  local subject, event = last_state[1], last_state[2]

  if event ~= #data[subject] then
    event = event + 1
  else
    -- Current subject exhausted: look for the next subject holding at least
    -- one event, wrapping around after the last one.
    event = 1
    local visited = 0
    repeat
      subject = subject + 1
      if subject > subject_count then
        subject = 1
      end
      visited = visited + 1
      -- Every subject has been inspected once without success: give up
      -- instead of looping forever.
      assert(visited <= subject_count, "All annotated subjects are empty.")
    until #data[subject] > 0
  end

  last_state[1], last_state[2] = subject, event

  return tablex.copy(data[subject][event].state)
end

return module.AnnotatedLabelledDataProvider