Source code for chap_core.training_control
from asyncio import CancelledError
[docs]
class TrainingControl:
    """Track training progress/status and carry a cancellation flag.

    The code being controlled reports work through :meth:`register_progress`
    and :meth:`set_status`. Once :meth:`cancel` has been called, the next call
    to either of those methods raises :class:`asyncio.CancelledError`.
    """

    def __init__(self):
        # Unknown until set_total_samples() is called; get_progress() reports 0 meanwhile.
        self._total_samples = None
        self._cancelled = False
        # NOTE: the initial status is the literal string "None", not the None object.
        self._status = "None"
        self._n_finished = 0

    def set_total_samples(self, total_samples):
        """Set the number of samples that corresponds to 100% progress."""
        self._total_samples = total_samples

    def get_progress(self):
        """Return the finished fraction ``n_finished / total_samples``.

        Returns 0 while the total is unknown (``None``) or zero, instead of
        raising ``ZeroDivisionError``. The value is not clamped, so it can
        exceed 1.0 if more samples are registered than the configured total.
        """
        if not self._total_samples:
            # Covers both "not set yet" (None) and an explicit total of 0.
            return 0
        return self._n_finished / self._total_samples

    def get_status(self):
        """Return the most recently set status (initially the string "None")."""
        return self._status

    def register_progress(self, n_sampled):
        """Add *n_sampled* to the count of finished samples.

        Raises:
            CancelledError: if cancel() has been called; the count is not updated.
        """
        if self._cancelled:
            raise CancelledError()
        self._n_finished += n_sampled

    def set_status(self, status):
        """Replace the current status with *status*.

        Raises:
            CancelledError: if cancel() has been called; the status is not updated.
        """
        if self._cancelled:
            raise CancelledError()
        self._status = status

    def cancel(self):
        """Request cancellation; later progress/status updates will raise CancelledError."""
        self._cancelled = True

    def is_cancelled(self):
        """Return True if cancel() has been called."""
        return self._cancelled
[docs]
class PrintingTrainingControl(TrainingControl):
    """A TrainingControl that also echoes each progress and status update to stdout."""

    def register_progress(self, n_sampled):
        """Record *n_sampled* finished samples via the base class, then print the percentage.

        If the base class raises (e.g. after cancellation), nothing is printed.
        """
        super().register_progress(n_sampled)
        percent_done = self.get_progress() * 100
        print(f"Progress: {percent_done:.2f}%")

    def set_status(self, status):
        """Store *status* via the base class, then print the stored status."""
        super().set_status(status)
        current_status = self.get_status()
        print(f"Status: {current_status}")