The `BaseCallback` Class in Machine Learning Training

I. Detailed Explanation of Callback Functions in Machine Learning

In Python deep learning and reinforcement learning libraries, `BaseCallback` is typically the base class that defines the interface for callback functions. A callback is an object whose methods are invoked at specific points during training to perform tasks such as logging, saving the model, or adjusting the learning rate.
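Before looking at the stable_baselines3 example below, it can help to see the callback pattern in isolation. The following is a minimal, self-contained sketch of the idea (the class and function names are invented for illustration, not part of any library): a toy training loop that invokes hook methods on a callback object at fixed points.

```python
class ToyBaseCallback:
    """Base class defining the hook interface; subclasses override hooks."""

    def on_training_start(self) -> None:
        pass

    def on_step(self, step: int) -> bool:
        # Returning False asks the loop to stop early.
        return True

    def on_training_end(self) -> None:
        pass


class LoggingCallback(ToyBaseCallback):
    """Records an event at the start, at every log_every-th step, and at the end."""

    def __init__(self, log_every: int = 2):
        self.log_every = log_every
        self.events = []  # record what happened, for inspection

    def on_training_start(self) -> None:
        self.events.append("start")

    def on_step(self, step: int) -> bool:
        if step % self.log_every == 0:
            self.events.append(f"step {step}")
        return True

    def on_training_end(self) -> None:
        self.events.append("end")


def train(total_steps: int, callback: ToyBaseCallback) -> None:
    """Toy training loop that calls the hooks, like model.learn() does."""
    callback.on_training_start()
    for step in range(1, total_steps + 1):
        if not callback.on_step(step):
            break  # callback requested early stopping
    callback.on_training_end()


cb = LoggingCallback(log_every=2)
train(5, cb)
print(cb.events)  # hooks fired in order: start, step 2, step 4, end
```

This is exactly the structure `BaseCallback` generalizes: a fixed set of hook methods, a training loop that calls them, and a boolean return value from the per-step hook that controls early stopping.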

from stable_baselines3.common.callbacks import BaseCallback

class CyberTrainingCallback(BaseCallback):
    def __init__(self, check_freq=1000, save_path=None, verbose=0):
        super().__init__(verbose)
        # Initialize variables used to record information during training
        self.best_mean_reward = -float('inf')
        self.last_mean_reward = -float('inf')
        self.check_freq = check_freq  # Run the check every check_freq steps
        self.save_path = save_path    # Path to save the model

    def _on_training_start(self) -> None:
        """
        Called at the start of training.
        """
        # You can initialize variables or print information here
        print("Training is starting!")

    def _on_step(self) -> bool:
        """
        Called at each step of training.
        """
        # Run the check every check_freq steps
        if self.n_calls % self.check_freq == 0:
            # Mean of the most recent per-environment rewards
            current_mean_reward = self.locals['rewards'].mean()
            print(f"Step {self.n_calls}: Mean reward = {current_mean_reward}")

            # Save the model if the current mean reward beats the best so far
            if current_mean_reward > self.best_mean_reward:
                self.best_mean_reward = current_mean_reward
                if self.save_path is not None:
                    self.model.save(self.save_path)
                    print(f"Model saved to {self.save_path}")

        return True  # Return True to continue training; returning False stops training

    def _on_training_end(self) -> None:
        """
        Called at the end of training.
        """
        # You can perform cleanup work or print information here
        print("Training has ended!")

# Use the callback
from stable_baselines3 import PPO

# Create a model
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)

# Create a callback instance
callback = CyberTrainingCallback(check_freq=1000, save_path='./best_model')

# Start training
model.learn(total_timesteps=10000, callback=callback)

Code explanation:

• Define a callback class:

• `CyberTrainingCallback` inherits from `BaseCallback`.

• In the `__init__` method, some variables are initialized, such as `best_mean_reward` to record the best average reward, `check_freq` to set the check frequency, and `save_path` to set the path to save the model.


• Callback method:

• `_on_training_start`: Called at the start of training, where you can initialize some variables or print some information.

• `_on_step`: Called at each training step, where checks and operations can be performed. For example, check the current average reward every `check_freq` steps and save the model when the reward improves.

• `_on_training_end`: Called at the end of training, where you can perform cleanup tasks or print information.


• Use callbacks:

• Create a `PPO` model.

• Create an instance of `CyberTrainingCallback` and set the check frequency and save path.

• When calling the `model.learn` method, pass the callback instance to the `callback` parameter so that the callback method will be invoked during training.

II. Commonly Used Callback Functions

Here are some commonly used callback functions and how to use them:

1. `_on_training_start()`
is called when training begins. You can initialize some variables or print some information here.

class CustomCallback(BaseCallback):
    def _on_training_start(self) -> None:
        """
        Called at the start of training.
        """
        print("Training is starting!")
        # Initialize some variables
        self.best_mean_reward = -float('inf')
        self.save_path = './best_model'

2. `_on_rollout_start()`
is called at the start of each rollout (i.e., each data-collection phase, in which the agent interacts with the environment to gather a batch of experience).

class CustomCallback(BaseCallback):
    def _on_rollout_start(self) -> None:
        """
        Called at the start of each rollout.
        """
        print("Rollout is starting!")

3. `_on_step()`
is called at each training step. You can perform per-step operations here, such as logging or adjusting the learning rate.

class CustomCallback(BaseCallback):
    def _on_step(self) -> bool:
        """
        Called at each training step.
        """
        # Check every certain number of steps
        if self.n_calls % 1000 == 0:
            print(f"Step {self.n_calls}")
        return True  # Return True to continue training, return False to stop training
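The learning-rate adjustment mentioned above can be sketched without stable_baselines3. `DummyOptimizer` below is a stand-in that mimics a PyTorch optimizer's `param_groups` list of dicts; in a real callback you would typically reach the actual optimizer through the model's policy (verify this against your algorithm), and the decay logic would live inside `_on_step`.

```python
class DummyOptimizer:
    """Stand-in mimicking a PyTorch optimizer's param_groups attribute."""

    def __init__(self, lr: float):
        self.param_groups = [{"lr": lr}]


def decay_lr_on_step(optimizer, n_calls: int, decay_every: int = 1000,
                     factor: float = 0.5) -> bool:
    """Logic one might put inside _on_step: decay the lr periodically."""
    if n_calls % decay_every == 0:
        for group in optimizer.param_groups:
            group["lr"] *= factor
    return True  # keep training


opt = DummyOptimizer(lr=3e-4)
for step in range(1, 2001):
    decay_lr_on_step(opt, step, decay_every=1000)
print(opt.param_groups[0]["lr"])  # decayed twice: 3e-4 * 0.5 * 0.5 = 7.5e-05
```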

4. `_on_rollout_end()`
is called at the end of each rollout. This is a good place for per-rollout operations such as saving the model or logging.

import numpy as np
from stable_baselines3.common.callbacks import BaseCallback

class CustomCallback(BaseCallback):
    def __init__(self, save_path=None, verbose=0):
        super().__init__(verbose)
        self.best_mean_reward = -float('inf')
        self.save_path = save_path

    def _on_rollout_end(self) -> None:
        """
        Called at the end of each rollout.
        """
        # ep_info_buffer is a deque of dicts with episode reward 'r' and length 'l'
        if len(self.model.ep_info_buffer) > 0:
            current_mean_reward = np.mean([ep_info['r'] for ep_info in self.model.ep_info_buffer])
            print(f"Rollout ended. Mean reward: {current_mean_reward}")

            # Save the model if the current mean reward beats the best so far
            if current_mean_reward > self.best_mean_reward:
                self.best_mean_reward = current_mean_reward
                if self.save_path is not None:
                    self.model.save(self.save_path)
                    print(f"Model saved to {self.save_path}")

5. `_on_training_end()`
is called when training ends. You can perform some cleanup work or print some information here.

class CustomCallback(BaseCallback):
    def _on_training_end(self) -> None:
        """
        Called at the end of training.
        """
        print("Training has ended!")

6. `CheckpointCallback`
is used to periodically save the model during training.

from stable_baselines3.common.callbacks import CheckpointCallback
 
# Create a CheckpointCallback instance
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./checkpoints')
 
# Using callbacks
model.learn(total_timesteps=10000, callback=checkpoint_callback)

7. `EvalCallback`
is used to periodically evaluate the model’s performance during training and save the model based on the evaluation results.

from stable_baselines3.common.callbacks import EvalCallback 
from stable_baselines3.common.env_util import make_vec_env

# Create an evaluation environment
eval_env = make_vec_env('CartPole-v1', n_envs=5)

# Create an EvalCallback instance
eval_callback = EvalCallback(eval_env, best_model_save_path='./best_model', log_path='./eval_logs', eval_freq=1000)

# Use the callback
model.learn(total_timesteps=10000, callback=eval_callback)

8. `StopTrainingOnRewardThreshold`
is used to stop training once the model's mean evaluation reward reaches a given threshold. Note that it does not work standalone: it must be attached to an `EvalCallback` via the `callback_on_new_best` parameter.

from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

# Stop once the mean evaluation reward reaches 200
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1)

# Attach it to an EvalCallback (eval_env as created in the EvalCallback example)
eval_callback = EvalCallback(eval_env, callback_on_new_best=stop_callback, eval_freq=1000, verbose=1)

# Use the callback
model.learn(total_timesteps=10000, callback=eval_callback)

9. `EveryNTimesteps`
is used to call another callback function every N time steps.

from stable_baselines3.common.callbacks import EveryNTimesteps 

# Create a custom callback
class CustomCallback(BaseCallback):
    def _on_step(self) -> bool:
        print(f"Step {self.n_calls}")
        return True

# Create an EveryNTimesteps instance that triggers CustomCallback every 1000 steps
callback = EveryNTimesteps(n_steps=1000, callback=CustomCallback())

# Use the callback
model.learn(total_timesteps=10000, callback=callback)

10. `CallbackList`
is used to combine multiple callback functions together, allowing you to use multiple callbacks simultaneously.

from stable_baselines3.common.callbacks import CallbackList
 
# Create multiple callbacks
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./checkpoints')
eval_callback = EvalCallback(eval_env, best_model_save_path='./best_model', log_path='./eval_logs', eval_freq=1000)
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1)
 
# Combine multiple callbacks together
callback = CallbackList([checkpoint_callback, eval_callback, stop_callback])
 
# Use the combined callback
model.learn(total_timesteps=10000, callback=callback)
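Conceptually, `CallbackList` forwards each hook to every child callback, and training stops as soon as any child's step hook returns False. A minimal sketch of that dispatch idea (invented toy classes, not the library's actual code):

```python
class CountingCallback:
    """Child callback that counts how often it was stepped."""

    def __init__(self):
        self.calls = 0

    def on_step(self, step: int) -> bool:
        self.calls += 1
        return True


class StopAtCallback:
    """Child callback that requests a stop at a given step."""

    def __init__(self, stop_at: int):
        self.stop_at = stop_at

    def on_step(self, step: int) -> bool:
        return step < self.stop_at


class ToyCallbackList:
    """Forwards on_step to every child; training continues only if all agree."""

    def __init__(self, callbacks):
        self.callbacks = callbacks

    def on_step(self, step: int) -> bool:
        keep_going = True
        for cb in self.callbacks:
            # Every child is stepped even if an earlier one asked to stop,
            # so each child still observes every step.
            keep_going = cb.on_step(step) and keep_going
        return keep_going


counter = CountingCallback()
combined = ToyCallbackList([counter, StopAtCallback(stop_at=3)])

step = 0
while True:
    step += 1
    if not combined.on_step(step):
        break

print(step, counter.calls)  # loop stopped at step 3; counter was stepped 3 times
```

Note the design choice of stepping every child even after one requests a stop: this keeps the children's internal counters consistent, which matters when children like `CheckpointCallback` track steps themselves.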