training_node.training_node
The TrainingNode base class implements an algorithm-agnostic training loop.
To implement a specific algorithm, a derived training node customizes the hooks that
run() calls. The general logic is to broadcast the initial model, receive samples,
validate them, check whether a training step should be taken, and broadcast the new model.
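The loop described above can be sketched as follows. This is a minimal, self-contained illustration, not the soulsai implementation: a `queue.Queue` stands in for Redis, and the class `SketchTrainingNode`, its hook names (`_validate_sample`, `_check_update_cond`, `_update_model`), and the example subclass `BatchNode` are all hypothetical.

```python
from __future__ import annotations

import queue
from abc import ABC, abstractmethod
from typing import Any


class SketchTrainingNode(ABC):
    """Hypothetical algorithm-agnostic loop; a queue.Queue stands in for Redis."""

    def __init__(self, sample_queue: queue.Queue, max_steps: int):
        self.sample_queue = sample_queue
        self.max_steps = max_steps
        self.buffer: list[Any] = []
        self.model_id = 0
        self.published: list[int] = []  # Records each "broadcast" model version

    def run(self) -> None:
        self._publish_model()  # Broadcast the initial model
        steps = 0
        while steps < self.max_steps:
            try:
                sample = self.sample_queue.get(timeout=0.1)
            except queue.Empty:
                break  # Toy setup: stop once no samples arrive
            if not self._validate_sample(sample):
                continue  # Discard samples the algorithm cannot use
            self.buffer.append(sample)
            if self._check_update_cond():
                self._update_model()
                self.model_id += 1
                self._publish_model()  # Broadcast the new model
                steps += 1

    def _publish_model(self) -> None:
        # A real node would upload parameters; here we only record the version
        self.published.append(self.model_id)

    # --- Hooks that derived classes customize ---
    @abstractmethod
    def _validate_sample(self, sample: Any) -> bool: ...

    @abstractmethod
    def _check_update_cond(self) -> bool: ...

    @abstractmethod
    def _update_model(self) -> None: ...


class BatchNode(SketchTrainingNode):
    """Hypothetical synchronous node: trains on every 4 samples from the current model."""

    def _validate_sample(self, sample: Any) -> bool:
        return sample["model_id"] == self.model_id  # Reject stale samples

    def _check_update_cond(self) -> bool:
        return len(self.buffer) >= 4

    def _update_model(self) -> None:
        self.buffer.clear()  # Placeholder for an actual optimization step


if __name__ == "__main__":
    q: queue.Queue = queue.Queue()
    for model_id in [0, 0, 0, 0, 0, 1, 1, 1, 1]:
        q.put({"model_id": model_id})
    node = BatchNode(q, max_steps=10)
    node.run()
    print(node.published)  # prints: [0, 1, 2]
```

Because the base class only fixes the order of the steps, each algorithm decides what counts as a usable sample: the synchronous `BatchNode` drops the fifth sample because it was produced by an outdated model, whereas an asynchronous variant could accept it.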
Each training node also runs a heartbeat service that tracks how many clients are currently active. The service can be customized to require that nodes (or a subset of them) stay continuously connected.
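A heartbeat service of this kind can be sketched with per-client timestamps: a client counts as active if its last heartbeat arrived within a timeout. The class `HeartbeatMonitor`, the 5-second default timeout, and the `required` set (modeling nodes that must stay connected) are hypothetical; the actual service's transport and parameters are not documented here.

```python
import time
from typing import Callable, Dict, Iterable, Optional, Set


class HeartbeatMonitor:
    """Hypothetical heartbeat tracker: a client is active if it beat within `timeout` s."""

    def __init__(
        self,
        timeout: float = 5.0,
        required: Optional[Iterable[str]] = None,
        clock: Callable[[], float] = time.monotonic,
    ):
        self.timeout = timeout
        self.required: Set[str] = set(required or ())
        self.clock = clock  # Injectable clock makes the logic testable
        self.last_seen: Dict[str, float] = {}

    def beat(self, client_id: str) -> None:
        """Record a heartbeat received from `client_id`."""
        self.last_seen[client_id] = self.clock()

    def active_clients(self) -> Set[str]:
        """Clients whose most recent heartbeat is no older than the timeout."""
        now = self.clock()
        return {cid for cid, seen in self.last_seen.items() if now - seen <= self.timeout}

    def num_active(self) -> int:
        return len(self.active_clients())

    def missing_required(self) -> Set[str]:
        """Required clients without a recent heartbeat, i.e. likely disconnected."""
        return self.required - self.active_clients()
```

A synchronous training node could, for instance, pause or abort while `missing_required()` is non-empty, whereas an asynchronous one might only report `num_active()`.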
- class soulsai.distributed.server.training_node.training_node.TrainingNode(config: SimpleNamespace)
Algorithm-agnostic base class for training nodes.
- run()
Run the training node.
Derived classes customize the hooks called in the loop to implement different learning algorithms. The main loop receives samples sent from worker nodes via Redis, verifies that the samples are usable, appends them to a buffer, checks whether the training step condition is met, and, if so, updates the agent and uploads the new parameters to Redis.
Additionally, the training node runs a heartbeat service to detect node disconnects. This matters primarily for synchronous algorithms that do not support adding or removing worker nodes dynamically.
- save_config(path: Path)
Save the training configuration to a file.
- Parameters:
path – Path to the configuration file.
- load_config(path: Path)
Load the training configuration from a file.
- Parameters:
path – Path to the configuration file.
- monitor_timing(prom_timer: Gauge)
Monitor the execution time of a code block and store it in the Prometheus Gauge.
Note
Only activates if Prometheus is enabled in the training config.
- Parameters:
prom_timer – A Prometheus Gauge object that is updated with the execution time.
- shutdown(_: Any)
Shut down the training node.
- abstract checkpoint(path: Path, options: dict = {})
Create a training checkpoint.
- Parameters:
path – Path to the save folder.
options – Dictionary of additional options to customize checkpointing.
- abstract load_checkpoint(path: Path)
Load a training checkpoint from a folder.
- Parameters:
path – Path to the save folder.
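The description of `monitor_timing` suggests a context manager that wraps a block of code. A minimal sketch under that assumption follows; the `prometheus` flag on the config and the `TimedNode` and `RecordingGauge` classes are hypothetical stand-ins. Any object with a `set(float)` method works as the gauge, so a real `prometheus_client.Gauge` could be passed in its place.

```python
import time
from contextlib import contextmanager
from types import SimpleNamespace
from typing import Iterator, Protocol


class GaugeLike(Protocol):
    """Anything with a ``set`` method, e.g. ``prometheus_client.Gauge``."""

    def set(self, value: float) -> None: ...


class RecordingGauge:
    """Hypothetical stand-in gauge that records every value it is set to."""

    def __init__(self) -> None:
        self.values: list = []

    def set(self, value: float) -> None:
        self.values.append(value)


class TimedNode:
    """Hypothetical node showing only the timing helper."""

    def __init__(self, config: SimpleNamespace):
        self.config = config

    @contextmanager
    def monitor_timing(self, prom_timer: GaugeLike) -> Iterator[None]:
        # Hypothetical flag: skip timing entirely when Prometheus is disabled
        if not getattr(self.config, "prometheus", False):
            yield
            return
        start = time.perf_counter()
        try:
            yield
        finally:
            # Store the elapsed time in seconds, even if the block raised
            prom_timer.set(time.perf_counter() - start)


if __name__ == "__main__":
    gauge = RecordingGauge()
    node = TimedNode(SimpleNamespace(prometheus=True))
    with node.monitor_timing(gauge):
        time.sleep(0.01)  # The "training step" being timed
    print(len(gauge.values))  # prints: 1
```

`prometheus_client.Gauge` also provides its own `time()` context manager; a wrapper like this mainly adds the check against the training config.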
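A derived class must implement `checkpoint` and `load_checkpoint`. The sketch below shows one way to do so under stated assumptions: the configuration is a (possibly nested) `SimpleNamespace` serialized as JSON, and the checkpoint folder holds that config plus a small JSON state file. The file names `config.json` and `state.json`, the `extra_state` option, the helper functions, and the `PersistentNode` and `CounterNode` classes are hypothetical; a real node would also save model and optimizer weights.

```python
import json
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Optional


def ns_to_dict(obj: Any) -> Any:
    """Recursively convert SimpleNamespace objects into plain dicts (hypothetical helper)."""
    if isinstance(obj, SimpleNamespace):
        return {key: ns_to_dict(value) for key, value in vars(obj).items()}
    return obj


def dict_to_ns(obj: Any) -> Any:
    """Recursively convert dicts back into SimpleNamespace objects (hypothetical helper)."""
    if isinstance(obj, dict):
        return SimpleNamespace(**{key: dict_to_ns(value) for key, value in obj.items()})
    return obj


class PersistentNode(ABC):
    """Hypothetical base class mirroring the documented config and checkpoint methods."""

    def __init__(self, config: SimpleNamespace):
        self.config = config

    def save_config(self, path: Path) -> None:
        path.write_text(json.dumps(ns_to_dict(self.config), indent=2))

    def load_config(self, path: Path) -> None:
        self.config = dict_to_ns(json.loads(path.read_text()))

    @abstractmethod
    def checkpoint(self, path: Path, options: Optional[dict] = None) -> None: ...

    @abstractmethod
    def load_checkpoint(self, path: Path) -> None: ...


class CounterNode(PersistentNode):
    """Toy node whose only training state is a step counter."""

    def __init__(self, config: SimpleNamespace):
        super().__init__(config)
        self.step = 0

    def checkpoint(self, path: Path, options: Optional[dict] = None) -> None:
        options = options or {}
        path.mkdir(parents=True, exist_ok=True)
        self.save_config(path / "config.json")
        # "extra_state" is a hypothetical option for storing additional values
        state = {"step": self.step, **options.get("extra_state", {})}
        (path / "state.json").write_text(json.dumps(state))

    def load_checkpoint(self, path: Path) -> None:
        self.load_config(path / "config.json")
        self.step = json.loads((path / "state.json").read_text())["step"]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        folder = Path(tmp) / "checkpoint"
        node = CounterNode(SimpleNamespace(lr=1e-3, algorithm=SimpleNamespace(name="dqn")))
        node.step = 42
        node.checkpoint(folder)
        restored = CounterNode(SimpleNamespace())
        restored.load_checkpoint(folder)
        print(restored.step, restored.config.algorithm.name)  # prints: 42 dqn
```

The sketch uses `options: Optional[dict] = None` instead of the documented `{}` default, the usual Python idiom for avoiding a mutable default argument shared across calls.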