Vous avez reçu un message "Your GitLab account has been locked ..." ? Pas d'inquiétude : lisez cet article https://docs.gricad-pages.univ-grenoble-alpes.fr/help/unlock/

Commit 5e010faf authored by Alexis Brenon's avatar Alexis Brenon
Browse files

📝 Document experiment package

parent 9ba1b62b
--- Classical experiment class.
--- Classical (monitored) experiment class.
--
-- This class can be inherited to add more features
-- @classmod experiment.BaseExperiment
-- @alias class
-- @inherit true
-- @see ArcadesComponent
-- @mtodo Refactor class following `ArcadesComponent` architecture.
-- @author Alexis BRENON <alexis.brenon@imag.fr>
local paths = require('paths')
......@@ -15,39 +18,108 @@ local Logger = require('arcades.utils.Logger')
local module = {}
local class, super = torch.class('BaseExperiment', 'ArcadesComponent', module)
--- Attributes
-- @section attributes
--- Experiment output configuration
-- @tfield table self.output
--- Total number of steps to do
--- Data Types
-- @section data-types
--- Argument used for instantiation.
-- @tfield {train=environment.BaseEnvironment,test=environment.BaseEnvironment} environment Table with `train` and `test` @{environment.BaseEnvironment|environments}
-- @tfield agent.BaseAgent agent The @{agent.BaseAgent|agent} to use
-- @tfield OutputArgument output Output options
-- @tfield[opt=math.huge] number steps Total number of steps to do (excluding evaluation steps)
-- @tfield number eval_freq Number of steps between two evaluations
-- @tfield number eval_steps Number of evaluation steps
-- @tfield[opt] number save_at Next iteration at which to save the experiment
-- @tfield[opt] {train=table,test=table} loop Results of the last iteration
-- @tfield[opt] {train=table,test=table} metrics Table of saved metrics
-- @tfield[opt] number step Current step
-- @mtodo Extract some arguments to `Dump`.
-- @table InitArgument
--- Dump.
-- @table Dump
--- Arguments to describe the output.
-- @table OutputArgument
--- Result of a set of interactions.
-- @tfield integer num_it Actual number of interactions done
-- @tfield {real=number,sys=number,user=number} time Times elapsed by interactions
-- @tfield integer num_rewards Number of non-zero rewards received
-- @tfield integer num_episodes Number of episodes (terminal states)
-- @tfield number total_reward Sum of rewards obtained
-- @tfield torch.Tensor/tensor.md/ confusion_matrix Confusion matrix
-- @tfield InteractionsTable interactions Details of interactions
-- @tfield ClassificationTable classification_metrics Summed-up metrics
-- @tfield {torch.Tensor/tensor.md/,...} inputs Inputs of the interactions
-- @table InteractionsResult
--- A list of interactions.
-- @tfield torch.Tensor/tensor.md/ expected_actions Actions expected by the environment
-- @tfield torch.Tensor/tensor.md/ actions Actions executed
-- @tfield torch.Tensor/tensor.md/ rewards Rewards obtained
-- @tfield torch.Tensor/tensor.md/ terminals Terminal signal of the input state
-- @table InteractionsTable
--- List of classification metrics computed.
--
-- See some explanations about metrics on this site:
-- [`http://blog.revolutionanalytics.com/2016/03/com_class_eval_metrics_r.html`](http://blog.revolutionanalytics.com/2016/03/com_class_eval_metrics_r.html)
-- @tfield number accuracy
-- @tfield number average_accuracy
-- @tfield number precision
-- @tfield number recall
-- @tfield number f1
-- @tfield number macro_precision
-- @tfield number macro_recall
-- @tfield number macro_f1
-- @tfield number micro_precision
-- @tfield number micro_recall
-- @tfield number micro_f1
-- @table ClassificationTable
--- Fields
-- @section fields
--- Table with `train` and `test` @{environment.BaseEnvironment|environments}
-- @tfield {train=environment.BaseEnvironment,test=environment.BaseEnvironment} self.environment
--- The @{agent.BaseAgent|agent} to use
-- @tfield agent.BaseAgent self.agent
--- Output options
-- @tfield OutputArgument self.output
--- Total number of steps to do (excluding evaluation steps)
-- @tfield number self.steps
--- Number of steps between two evaluation phases
--- Number of steps between two evaluations
-- @tfield number self.eval_freq
--- Number of steps in an evaluation phase
--- Number of evaluation steps
-- @tfield number self.eval_steps
--- Timer to measure the full running time
-- @tfield torch.Timer self.timer
--- A @{torch.Timer/timer.md/|Timer} used to time the experiment.
-- @tfield torch.Timer/timer.md/ self.timer
--- Next iteration at which to save the experiment
-- @tfield number self.save_at
--- Results of the last iteration
-- @tfield {train=table,test=table} self.loop
--- Table of saved metrics
-- @tfield {train=table,test=table} self.metrics
--- @{torch.DiskFile/diskfile.md/|File} used as output.
-- @tfield torch.DiskFile/diskfile.md/ self.metrics_file
--- Current step
-- @tfield number self.step
--- @section end
--- Default constructor.
-- @tparam table args
-- @tparam table args.environment Table with <code>train</code> and <code>test</code> environments
-- @param args.agent The agent to use
-- @tparam table args.output Output options (see @{main})
-- @tparam[opt=math.huge] number args.steps Total number of steps to do (excluding evaluation steps)
-- @tparam number args.eval_freq Number of steps between two evaluations
-- @tparam number args.eval_steps Number of evaluation steps
-- @tparam[opt] number args.save_at Next interation at which to save the experiment
-- @tparam[opt] table args.loop Results of the last iteration
-- @tparam[opt] table args.metrics Table of saved metrics
-- @tparam[opt] number args.step Current step
-- @tparam InitArgument args
function class:__init(args)
super.__init(self, args)
......@@ -85,13 +157,13 @@ function class:__init(args)
self.step = args.step or 0
end
--- Public methods
--- Public Methods
-- @section public-methods
--- Start the experiment.
--
-- This function will run @{steps} learning interactions, separated by
-- @{eval_steps} evaluation interactions each @{eval_freq} interactions.
-- This function will run `steps` learning interactions, separated by
-- `eval_steps` evaluation interactions each `eval_freq` interactions.
function class:run()
self._logger:debug("Starting experiment...")
local steps_decimal_length = math.ceil(math.log(self.steps, 10)) + 1
......@@ -148,7 +220,7 @@ function class:run()
end
--- Report loop results.
-- @return self
-- @return `self`
function class:report()
local _
_ = self.output.report.metrics and self:_metrics_report()
......@@ -160,7 +232,7 @@ function class:report()
end
--- Save the current experiment and dependencies.
-- @return self
-- @return `self`
function class:save()
if self.step >= self.save_at or self.step >= self.steps then
self.save_at = self.save_at + self.output.save_freq
......@@ -207,11 +279,11 @@ function class:save()
return self
end
--- setSteps update the total number of steps to execute.
--- Update the total number of steps to execute.
--
-- Use this function of you want to continue a previous experiment.
-- Use this function if you want to continue a previous experiment.
-- @tparam number steps New number of steps to do
-- @return self
-- @return `self`
function class:setSteps(steps)
if self.steps < steps then
for _, metric in pairs(self.metrics) do
......@@ -234,9 +306,11 @@ function class:setSteps(steps)
return self
end
--- Private methods
--- Private Methods
-- @section private-methods
--- Build and save metrics string.
-- @return `self`
function class:_metrics_report()
local total_time = self.timer:time().real
......@@ -292,6 +366,8 @@ function class:_metrics_report()
return self
end
--- Build and print a quick textual report.
-- @return `self`
function class:_text_report()
local training_time = self.loop.train.time.real
local training_rate = self.loop.train.num_it / training_time
......@@ -351,6 +427,9 @@ function class:_text_report()
return self
end
--- Save Torch components.
-- @return `self`
-- @mtodo save NN weights
function class:_torch_report()
if self.output.metrics.classification then
torch.save(
......@@ -372,8 +451,6 @@ function class:_torch_report()
self.loop.test.interactions
)
end
--- @todo TODO save NN weights
return self
end
......@@ -381,7 +458,7 @@ end
--
-- This function will plot some graph, save images or something like this about
-- the elements of the experiment.
-- @return self
-- @return `self`
function class:_graphical_report()
self._logger:debug("Saving graphical representations.")
for phase, metric in pairs(self.metrics) do
......@@ -420,7 +497,7 @@ function class:_graphical_report()
end
--- Plot evolution of reward per episode
-- @return self
-- @return `self`
function class:_plot_reward_per_ep()
-- Plot reward per episode
local abscissa = torch.range(0, self.step, self.eval_freq)/1000
......@@ -446,7 +523,7 @@ function class:_plot_reward_per_ep()
end
--- Save a graphical version of confusion matrix.
-- @return self
-- @return `self`
function class:_plot_confusion_matrix()
-- Save the confusion matrix as a PNG image
local confusion_matrix = self.loop.test.confusion_matrix:clone()
......@@ -488,6 +565,8 @@ function class:_plot_confusion_matrix()
return self
end
--- Plot evolution of the F1-Score.
-- @return `self`
function class:_plot_f1_score()
-- Plot F1-score evolution
local abscissa = torch.range(0, self.step, self.eval_freq)/1000
......@@ -512,6 +591,10 @@ function class:_plot_f1_score()
return self
end
--- Save some inputs.
--
-- This can be used for checks and/or post-mortem debug.
-- @return `self`
function class:_inputs_report()
self._logger:debug("Saving input frames.")
if #self.loop.train.inputs > 0 then
......@@ -535,9 +618,8 @@ function class:_inputs_report()
return self
end
--- Save a graphical representation of agent network filters
-- @return self
-- @return `self`
function class:_plot_network_filters()
if self.agent.inference and self.agent.inference.network then
local output_path = paths.concat(self.save_path, "filters")
......@@ -552,7 +634,7 @@ end
--- Dump images of the filters of the convolutional network.
-- @tparam string output_path Base output path for images
-- @param network Network to dump
-- @tparam nn.Container/containers.md/ network Network to dump
function class.draw_filters(output_path, network)
for i, net_module in ipairs(network:listModules()) do
if torch.type(net_module.__metatable) == "nn.Module" and
......@@ -579,10 +661,7 @@ end
--- Do some training interactions.
-- @tparam number steps Number of interactions to do
-- @treturn table A table with fields:
-- <ul>
-- <li><code>time:</code> <code>real</code>, <code>sys</code>, and <code>user</code> times</li>
-- </ul>
-- @treturn InteractionsResult Result of the interactions
function class:_train(steps)
local timer = torch.Timer()
self.environment.train:reset()
......@@ -620,13 +699,7 @@ end
--- Do some testing/evaluation interactions.
-- @tparam number steps Number of interactions to do
-- @treturn table A table with fields:
-- <ul>
-- <li><code>time:</code> <code>real</code>, <code>sys</code>, and <code>user</code> times</li>
-- <li><code>num_rewards:</code> Number of rewards obtained</li>
-- <li><code>num_episodes:</code> Number of episodes played (number of final states encountered)</li>
-- <li><code>total_reward:</code> Sum of the rewards obtained during evaluation</li>
-- </ul>
-- @treturn InteractionsResult Result of the interactions
function class:_test(steps)
local timer = torch.Timer()
self.environment.test:reset()
......@@ -663,7 +736,8 @@ end
-- interactions if necessary.
-- @tparam environment.BaseEnvironment environment Environment with which to interact
-- @tparam number steps Number of interactions to execute
-- @treturn torch.Tensor A tensor of size {@{steps}, 2} containing terminal signal and reward for each step
-- @treturn InteractionsResult Only a subset (`num_it`, `interactions`, `confusion_matrix`,
-- `inputs`) of fields are defined
function class:_interact(environment, steps)
local num_it = 0
local state, action, reward
......@@ -720,17 +794,15 @@ function class:_interact(environment, steps)
}
end
--- @section end
--- Class methods
-- @section class-methods
--- Static Functions
-- @section static-functions
--- Compute classification metrics from a multi-classes confusion matrix.
--
-- See some explanations about metrics on this site:
-- http://blog.revolutionanalytics.com/2016/03/com_class_eval_metrics_r.html
-- @tparam torch.Tensor confusion_matrix A 2D matrix
-- @treturn table Classification metrics like: accuracy, per-class/macro/micro-precision/recall
-- [`http://blog.revolutionanalytics.com/2016/03/com_class_eval_metrics_r.html`](http://blog.revolutionanalytics.com/2016/03/com_class_eval_metrics_r.html)
-- @tparam torch.Tensor/tensor.md/ confusion_matrix A 2D matrix
-- @treturn ClassificationTable Classification metrics
function class._compute_classification_metrics(confusion_matrix)
local diag = confusion_matrix:diag()
local rowsums = confusion_matrix:sum(2):squeeze()
......
--- A set of different experiments which link an @{environment} and an @{agent}.
--- A set of different experiments which link an `environment` and an `agent`.
--
-- List of the classes in this package:
-- <ul>
-- <li>@{experiment.BaseExperiment|BaseExperiment}</li>
-- </ul>
-- @module experiment
-- @alias package
--
-- * @{experiment.BaseExperiment|BaseExperiment}
--
-- @package experiment
-- @mtodo Make `BaseExperiment` abstract and implement a `MonitoredExperiment` instead.
-- @author Alexis BRENON <alexis.brenon@imag.fr>
return require('arcades.utils.package_loader')(...)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment