Commit 7c73acb1 authored by Alexis Brenon's avatar Alexis Brenon
Browse files

Fix experiment dumping and to string

parent 3b124cf1
......@@ -60,6 +60,7 @@ function class:__init(args)
self.eval_steps = args.eval_steps
self.timer = torch.Timer():stop():reset()
self._no_dump_list.timer = true
self.save_at = args.save_at or self.output.save_freq
self.loop = args.loop or {
......@@ -71,6 +72,7 @@ function class:__init(args)
test = {}
}
self.metrics_file = torch.DiskFile(paths.concat(self.output.path, "metrics.dat"), "w")
self._no_dump_list.metrics_file = true
self.metrics_file:writeString(
"#1-step\t2-total_time (s)\t" ..
"3-training_time (s)\t4-training_rate (it/s)\t" ..
......
......@@ -10,8 +10,7 @@ local arcades = require('arcades.utils.package_loader')(...)
local class = torch.class('ArcadesComponent', arcades)
function class:__init(args, dump)
args = args or {}
dump = dump or {}
args = args or {}; dump = dump or {}
self._args = tablex.copy(args)
self._name = string.format(
"%s(%X)",
......@@ -27,9 +26,10 @@ function class:dump(cycles)
cycles = cycles or {}
if self._logger then self._logger:debug("Dumping...") end
cycles[torch.pointer(self)] = true
local args_dump = class._dump(self._args, cycles)
local dump = class._dump(self, cycles)
dump.classname = torch.typename(self)
dump._args = class._dump(self._args, cycles)
dump._args = args_dump
return dump
end
......@@ -38,7 +38,7 @@ function class:__tostring__()
result = string.format(
"%s(\n%s)",
dump.classname,
require('pl.pretty').write(dump._args)
arcades.utils.utils.pretty_repr(dump._args)
)
return result
end
......
......@@ -84,7 +84,7 @@ function class.timestamp_output(self, level, msg)
os.date("%Y%m%dT%H%M%S"),
level,
self.name or debug.getinfo(2, "S").short_src,
debug.getinfo(2, "l").currentline,
debug.getinfo(4, "l").currentline,
msg
)
class.need_lf = false
......
......@@ -67,4 +67,76 @@ function module.load_all()
require('arcades.utils.placeholder').load_all()
end
-- Copy of the new_print Torch function but return a string
-- Special case, don't print big Tensors (more than 25 elements), just their size
local ndepth = 6
function module.pretty_repr(...)
local result = ""
local function rawprint(o)
result = result .. tostring(o or '') .. '\n'
end
local function printtensor(obj)
local size = ""
for j = 1, obj:dim() do
size = size .. obj:size(j)
if j < obj:dim() then size = size .. "x" end
end
return "[" .. torch.type(obj) .. " of size " .. size .. "]"
end
local objs = {...}
local function printrecursive(obj,depth)
local depth = depth or 0
local tab = depth*4
local line = function(s) for i=1,tab do result = result .. ' ' end rawprint(s) end
if next(obj) then
line('{')
tab = tab+2
for k,v in pairs(obj) do
if (
string.match(torch.type(v), "torch%..*Tensor") and
v:numel() > 25
) then
line(tostring(k) .. ' : ' .. printtensor(v))
elseif torch.type(v) == 'table' then
if depth >= (ndepth-1) or next(v) == nil then
line(tostring(k) .. ' : {...}')
else
line(tostring(k) .. ' : ') printrecursive(v,depth+1)
end
else
line(tostring(k) .. ' : ' .. tostring(v))
end
end
tab = tab-2
line('}')
else
line('{}')
end
end
--require('mobdebug').start()
for i = 1,select('#',...) do
local obj = select(i,...)
if type(obj) ~= 'table' then
if (
string.match(torch.type(obj), "torch%..*Tensor") and
obj:numel() > 25
) then
rawprint(printtensor(obj))
elseif type(obj) == 'userdata' or type(obj) == 'cdata' then
rawprint(obj)
else
result = result .. obj .. '\t'
if i == select('#',...) then
rawprint()
end
end
elseif getmetatable(obj) and getmetatable(obj).__tostring then
rawprint(obj)
else
printrecursive(obj)
end
end
return result
end
return module
......@@ -9,20 +9,15 @@
--- @usage
local _ = [[
th ./src/main.lua <options>
th ./main.lua <options>
See src/utils/argparse.lua for a full list of options
See arcades/utils/argparse.lua for a full list of options
]]
local torch = require('torch')
local paths = require('paths')
local basedir = paths.dirname(debug.getinfo(1).short_src)
package.path = string.format(
"%s/?.lua;%s/?/init.lua;./?/init.lua;%s",
basedir, basedir,
package.path
)
package.path = "./?/init.lua;" .. package.path
local argparse = require('arcades.utils.argparse')
local setup = require('arcades.utils.setup')
......@@ -68,4 +63,4 @@ local function main()
experiment:run()
end
return main
main()
......@@ -31,7 +31,7 @@ local function run(args)
str_args = str_args .. string.format(' -%s %q', k, write(v, ""))
end
local command = "th src/main.lua" .. str_args .. persist_command
local command = "th ./main.lua" .. str_args .. persist_command
print(command)
os.execute(string.format("bash -c %q", command))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment