# Source code for chap_core.training_control

from asyncio import CancelledError


class TrainingControl:
    """Track progress and status of a training run, and allow it to be cancelled.

    Callers report progress with :meth:`register_progress` and status changes
    with :meth:`set_status`. After :meth:`cancel` has been called, the next call
    to either of those methods raises :class:`asyncio.CancelledError` instead of
    updating state. This interrupts whoever is reporting progress.
    """

    def __init__(self):
        # Unknown until set_total_samples() is called; get_progress() reports 0 meanwhile.
        self._total_samples = None
        self._cancelled = False
        # NOTE: initial status is the *string* "None", not the None object.
        self._status = "None"
        self._n_finished = 0

    def set_total_samples(self, total_samples):
        """Set the total number of samples used as the denominator of progress."""
        self._total_samples = total_samples

    def get_progress(self):
        """Return the fraction of samples finished so far.

        Returns 0 while the total is unknown (``None``) or zero. Returning 0 in
        these cases avoids a ``ZeroDivisionError``. The value is not clamped, so
        it can exceed 1 if more samples are registered than the declared total.
        """
        if self._total_samples is None or self._total_samples == 0:
            return 0
        return self._n_finished / self._total_samples

    def get_status(self):
        """Return the most recently set status (initially the string ``"None"``)."""
        return self._status

    def register_progress(self, n_sampled):
        """Add ``n_sampled`` to the count of finished samples.

        Raises:
            CancelledError: if :meth:`cancel` has been called. The count is
                left unchanged in that case.
        """
        if self._cancelled:
            raise CancelledError()
        self._n_finished += n_sampled

    def set_status(self, status):
        """Replace the current status.

        Raises:
            CancelledError: if :meth:`cancel` has been called. The status is
                left unchanged in that case.
        """
        if self._cancelled:
            raise CancelledError()
        self._status = status

    def cancel(self):
        """Request cancellation.

        This method only sets a flag. The exception is raised by the next
        :meth:`register_progress` or :meth:`set_status` call.
        """
        self._cancelled = True

    def is_cancelled(self):
        """Return ``True`` once :meth:`cancel` has been called."""
        return self._cancelled
class PrintingTrainingControl(TrainingControl):
    """A :class:`TrainingControl` that echoes each update to stdout.

    Every successful progress or status update is followed by one printed
    line. If the base class raises (for example after cancellation), nothing
    is printed.
    """

    def register_progress(self, n_sampled):
        """Record progress through the base class, then print the percentage done."""
        super().register_progress(n_sampled)
        percent_done = self.get_progress() * 100
        print(f"Progress: {percent_done:.2f}%")

    def set_status(self, status):
        """Store the status through the base class, then print it."""
        super().set_status(status)
        current_status = self.get_status()
        print(f"Status: {current_status}")