Commit 00141f64 authored by Alexis Brenon's avatar Alexis Brenon
Browse files

馃殤 Fix dump reloading

Use a more explicit argument name (so called dump) and recurcively load
arguments if necessary.
parent ed114a0f
......@@ -93,14 +93,28 @@ function class.load(dump)
end
end
local class = find_class_by_name(dump.classname, arcades)
if not class then
error(string.format(
"Unable to find '%s' class in arcades hierarchy.",
dump.classname
))
local function recursive_load(dump)
local class = find_class_by_name(dump.classname, arcades)
if not class then
error(string.format(
"Unable to find '%s' class in arcades hierarchy.",
dump.classname
))
end
for k, v in pairs(dump._args) do
if (
type(v) == "table" and
v.classname and v._args
) then
dump._args[k] = recursive_load(v)
end
end
return class(dump._args, dump)
end
return class(dump._args, dump)
return recursive_load(dump)
end
return arcades
......@@ -70,7 +70,7 @@ function class:__init()
fields = {
{"field", "class", "", "Environment class to use", "string"},
{"field", "params", "", "Environment parameters (see class documentation)", "table"},
{"field", "file", "", "File to load as environment dump", "string"},
{"field", "dump", "", "File to load as environment dump", "string"},
}
},
......@@ -78,7 +78,7 @@ function class:__init()
fields = {
{"field", "class", "", "Agent class to use", "string"},
{"field", "params", "", "Agent parameters (see class documentation)", "table"},
{"field", "file", "", "File to load as agent dump", "string"},
{"field", "dump", "", "File to load as agent dump", "string"},
}
},
......@@ -86,7 +86,7 @@ function class:__init()
fields = {
{"field", "class", "", "Experiment class to use", "string"},
{"field", "params", "", "Experiment parameters (see class documentation)", "table"},
{"field", "file", "", "File to load as experiment dump", "string"},
{"field", "dump", "", "File to load as experiment dump", "string"},
}
},
}
......
......@@ -107,13 +107,11 @@ local function _environment_setup(args)
local env_class, err = _get_class(env_args.class, environment)
if err then error(err) end
envs[env] = env_class(env_args.params):reset()
elseif env_args.file and env_args.file ~= "" then
Logger.main_logger:debug("Loading an environment dump: " .. env_args.file)
utils.load_declarations()
local dump = torch.load(env_args.file)
local env_class, err = _get_class(dump.classname, environment)
if err then error(err) end
envs[env] = env_class(dump)
elseif env_args.dump and env_args.dump ~= "" then
Logger.main_logger:debug("Loading an environment dump: " .. env_args.dump)
envs[env] = arcades.ArcadesComponent.load(
torch.load(env_args.dump)
)
end
end
end
......@@ -171,13 +169,11 @@ local function _agent_setup(args, environments)
)
end
agent = ag_class(ag_args.params)
elseif ag_args.file and ag_args.file ~= "" then
logging.mainLogger:debug("Loading an agent dump: " .. ag_args.file)
utils.load_declarations()
local dump = torch.load(ag_args.file)
local ag_class, err = _get_class(dump.classname, require('arcades.agent'))
if err then error(err) end
agent = ag_class(dump)
elseif ag_args.dump and ag_args.dump ~= "" then
Logger.main_logger:debug("Loading an agent dump: " .. ag_args.dump)
agent = arcades.ArcadesComponent.load(
torch.load(ag_args.dump)
)
end
end
......@@ -196,21 +192,18 @@ local function _experiment_setup(args, agent, environments)
exp_args.params.agent = agent
exp_args.params.output = args.output
experiment = exp_class(exp_args.params)
elseif exp_args.file and exp_args.file ~= "" then
logging.mainLogger:debug("Loading an experiment dump: " .. exp_args.file)
utils.load_declarations()
local dump = torch.load(exp_args.file)
dump.environment = _environment_setup({
training_environment = {file = dump.environment.train},
testing_environment = {file = dump.environment.test}
elseif exp_args.dump and exp_args.dump ~= "" then
Logger.main_logger:debug("Loading an experiment dump: " .. exp_args.dump)
local dump = torch.load(exp_args.dump)
dump._args.environment = _environment_setup({
training_environment = {dump = dump.environment.train},
testing_environment = {dump = dump.environment.test}
})
dump.agent = _agent_setup(
{agent = {file = dump.agent}},
dump.environment
dump._args.agent = _agent_setup(
{agent = {dump = dump.agent}},
dump._args.environment
)
local exp_class, err = _get_class(dump.classname, require('arcades.experiment'))
if err then error(err) end
experiment = exp_class(dump)
experiment = arcades.ArcadesComponent.load(dump)
end
end
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment