
SARSA

Figure: SARSA algorithm pseudocode [1]

toyrl.sarsa.default_config module-attribute

default_config = SarsaConfig(env_name='CartPole-v1', render_mode=None, solved_threshold=475.0, max_training_steps=2000000, learning_rate=0.01, log_wandb=True)

toyrl.sarsa.trainer module-attribute

toyrl.sarsa.SarsaConfig dataclass

SarsaConfig(env_name: str = 'CartPole-v1', render_mode: str | None = None, solved_threshold: float = 475.0, gamma: float = 0.999, max_training_steps: int = 500000, learning_rate: float = 0.00025, log_wandb: bool = False)

Configuration for the SARSA algorithm. A short construction sketch follows the attribute list below.

env_name class-attribute instance-attribute

env_name: str = 'CartPole-v1'

render_mode class-attribute instance-attribute

render_mode: str | None = None

solved_threshold class-attribute instance-attribute

solved_threshold: float = 475.0

gamma class-attribute instance-attribute

gamma: float = 0.999

The discount factor applied to future rewards.

max_training_steps class-attribute instance-attribute

max_training_steps: int = 500000

The maximum number of environment steps to train for.

learning_rate class-attribute instance-attribute

learning_rate: float = 0.00025

The learning rate for the optimizer.

log_wandb class-attribute instance-attribute

log_wandb: bool = False

Whether to log the training process to Weights and Biases.
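
A minimal construction sketch (the overridden values below are illustrative, not tuned recommendations); any subset of the defaults can be replaced:

from toyrl.sarsa import SarsaConfig

# Override a few fields; everything not passed keeps its dataclass default.
config = SarsaConfig(
    env_name="CartPole-v1",
    max_training_steps=2_000_000,
    learning_rate=0.01,
    log_wandb=False,
)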

toyrl.sarsa.PolicyNet

PolicyNet(env_dim: int, action_num: int)

Bases: Module

Source code in toyrl/sarsa.py
def __init__(
    self,
    env_dim: int,
    action_num: int,
) -> None:
    super().__init__()
    self.env_dim = env_dim
    self.action_num = action_num

    layers = [
        nn.Linear(self.env_dim, 128),
        nn.ReLU(),
        nn.Linear(128, self.action_num),
    ]
    self.model = nn.Sequential(*layers)

env_dim instance-attribute

env_dim = env_dim

action_num instance-attribute

action_num = action_num

model instance-attribute

model = Sequential(*layers)

forward

forward(x: Tensor) -> Tensor
Source code in toyrl/sarsa.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.model(x)
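
A small usage sketch, assuming CartPole-v1's sizes (4-dimensional observation, 2 discrete actions); despite the PolicyNet name, the output is read as one Q-value estimate per action:

import torch

from toyrl.sarsa import PolicyNet

net = PolicyNet(env_dim=4, action_num=2)  # CartPole-v1 sizes, assumed here
obs = torch.randn(4)                      # a single observation
q_values = net(obs)                       # shape (2,): one value estimate per action
greedy_action = int(torch.argmax(q_values).item())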

toyrl.sarsa.Experience dataclass

Experience(terminated: bool, truncated: bool, observation: Any, action: Any, reward: float, next_observation: Any = None, next_action: Any = None)

terminated instance-attribute

terminated: bool

truncated instance-attribute

truncated: bool

observation instance-attribute

observation: Any

action instance-attribute

action: Any

reward instance-attribute

reward: float

next_observation class-attribute instance-attribute

next_observation: Any = None

next_action class-attribute instance-attribute

next_action: Any = None
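
When the trainer stores a transition, next_action is left as None; ReplayBuffer.sample later pairs consecutive transitions to fill in the (s', a') part. An illustrative construction (values are made up for the example):

from toyrl.sarsa import Experience

# Illustrative values for a single CartPole transition.
exp = Experience(
    terminated=False,
    truncated=False,
    observation=[0.01, 0.02, -0.03, 0.04],
    action=1,
    reward=1.0,
    next_observation=[0.02, 0.21, -0.04, -0.25],
)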

toyrl.sarsa.ReplayBuffer dataclass

ReplayBuffer(buffer: list[Experience] = list())

buffer class-attribute instance-attribute

buffer: list[Experience] = field(default_factory=list)

__len__

__len__() -> int
Source code in toyrl/sarsa.py
def __len__(self) -> int:
    return len(self.buffer)

add_experience

add_experience(experience: Experience) -> None
Source code in toyrl/sarsa.py
def add_experience(self, experience: Experience) -> None:
    self.buffer.append(experience)

reset

reset() -> None
Source code in toyrl/sarsa.py
def reset(self) -> None:
    self.buffer = []

sample

sample(with_next_sa: bool = True) -> list[Experience]
Source code in toyrl/sarsa.py
def sample(self, with_next_sa: bool = True) -> list[Experience]:
    if with_next_sa is False:
        return self.buffer
    else:
        res = []
        for i in range(len(self.buffer) - 1):
            experience = self.buffer[i]
            next_experience = self.buffer[i + 1]
            res.append(
                Experience(
                    observation=experience.observation,
                    action=experience.action,
                    reward=experience.reward,
                    next_observation=next_experience.observation,
                    next_action=next_experience.action,
                    terminated=next_experience.terminated,
                    truncated=next_experience.truncated,
                )
            )
        return res
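
With with_next_sa=True (the default), sample pairs each stored experience with its successor, yielding len(buffer) - 1 complete (s, a, r, s', a') tuples whose next_observation, next_action, terminated, and truncated fields come from the successor. A small sketch with illustrative values:

from toyrl.sarsa import Experience, ReplayBuffer

buffer = ReplayBuffer()
buffer.add_experience(
    Experience(terminated=False, truncated=False, observation=[0.0, 0.0, 0.0, 0.0], action=0, reward=1.0)
)
buffer.add_experience(
    Experience(terminated=True, truncated=False, observation=[0.1, 0.1, 0.1, 0.1], action=1, reward=1.0)
)

sarsa_tuples = buffer.sample()   # with_next_sa=True by default
assert len(sarsa_tuples) == 1    # len(buffer) - 1 paired transitions
first = sarsa_tuples[0]
assert first.next_action == 1    # taken from the second experience
assert first.terminated is True  # flags also come from the second experience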

toyrl.sarsa.Agent

Agent(policy_net: PolicyNet, optimizer: Optimizer)
Source code in toyrl/sarsa.py
def __init__(self, policy_net: PolicyNet, optimizer: torch.optim.Optimizer) -> None:
    self._policy_net = policy_net
    self._optimizer = optimizer
    self._replay_buffer = ReplayBuffer()
    self._action_num = policy_net.action_num

onpolicy_reset

onpolicy_reset() -> None
Source code in toyrl/sarsa.py
def onpolicy_reset(self) -> None:
    self._replay_buffer.reset()

add_experience

add_experience(experience: Experience) -> None
Source code in toyrl/sarsa.py
def add_experience(self, experience: Experience) -> None:
    self._replay_buffer.add_experience(experience)

act

act(observation: ndarray, epsilon: float) -> int
Source code in toyrl/sarsa.py
def act(self, observation: np.ndarray, epsilon: float) -> int:
    if np.random.rand() < epsilon:
        action = np.random.randint(self._action_num)
        return action
    x = torch.from_numpy(observation.astype(np.float32))
    with torch.no_grad():
        logits = self._policy_net(x)
    action = int(torch.argmax(logits).item())
    return action
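
Action selection is epsilon-greedy over the network's Q-value estimates: with probability epsilon a uniformly random action is returned, otherwise the argmax action. A self-contained sketch (the CartPole-v1 sizes and the Adam learning rate are assumptions, not values taken from this module):

import numpy as np
import torch

from toyrl.sarsa import Agent, PolicyNet

policy_net = PolicyNet(env_dim=4, action_num=2)
optimizer = torch.optim.Adam(policy_net.parameters(), lr=2.5e-4)
agent = Agent(policy_net=policy_net, optimizer=optimizer)

observation = np.zeros(4, dtype=np.float32)        # illustrative observation
exploratory = agent.act(observation, epsilon=1.0)  # always a random action
greedy = agent.act(observation, epsilon=0.0)       # always the argmax action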

policy_update

policy_update(gamma: float) -> float
Source code in toyrl/sarsa.py
def policy_update(self, gamma: float) -> float:
    experiences = self._replay_buffer.sample()

    observations = torch.tensor([experience.observation for experience in experiences])
    actions = torch.tensor([experience.action for experience in experiences], dtype=torch.float32)
    next_observations = torch.tensor([experience.next_observation for experience in experiences])
    next_actions = torch.tensor([experience.next_action for experience in experiences])
    rewards = torch.tensor([experience.reward for experience in experiences])
    terminated = torch.tensor(
        [experience.terminated for experience in experiences],
        dtype=torch.float32,
    )

    # q preds
    action_q_preds = self._policy_net(observations).gather(1, actions.long().unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_action_q_preds = self._policy_net(next_observations).gather(1, next_actions.unsqueeze(1)).squeeze(1)
        q_targets = rewards + gamma * (1 - terminated) * next_action_q_preds
    loss = torch.nn.functional.mse_loss(action_q_preds, q_targets)
    # update
    self._optimizer.zero_grad()
    loss.backward()
    # clip grad
    torch.nn.utils.clip_grad_norm_(self._policy_net.parameters(), max_norm=1.0)
    self._optimizer.step()
    return loss.item()
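
For reference, the loss computed above is the mean-squared error between the current Q-value estimate and the one-step SARSA target, where the terminated flag attached to each sampled tuple masks out the bootstrap term (a restatement of the code, not an addition to it):

y_t = r_t + \gamma \,(1 - d_t)\, Q_\theta(s_{t+1}, a_{t+1}), \qquad \mathcal{L}(\theta) = \frac{1}{N} \sum_t \big( Q_\theta(s_t, a_t) - y_t \big)^2

Here d_t is the terminated flag of the sampled tuple, and y_t is evaluated under torch.no_grad() so gradients flow only through Q_\theta(s_t, a_t).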

toyrl.sarsa.SarsaTrainer

SarsaTrainer(config: SarsaConfig)
Source code in toyrl/sarsa.py
def __init__(self, config: SarsaConfig) -> None:
    self.config = config
    self.env = self._make_env()
    if isinstance(self.env.action_space, gym.spaces.Discrete) is False:
        raise ValueError("Only discrete action space is supported.")
    env_dim = self.env.observation_space.shape[0]
    action_num = self.env.action_space.n
    policy_net = PolicyNet(env_dim=env_dim, action_num=action_num)
    optimizer = optim.Adam(policy_net.parameters(), lr=config.learning_rate)
    self.agent = Agent(policy_net=policy_net, optimizer=optimizer)

    self.gamma = config.gamma
    self.solved_threshold = config.solved_threshold
    if config.log_wandb:
        wandb.init(
            # set the wandb project where this run will be logged
            project="SARSA",
            name=f"[{config.env_name}],lr={config.learning_rate}",
            # track hyperparameters and run metadata
            config=asdict(config),
        )

config instance-attribute

config = config

env instance-attribute

env = _make_env()

agent instance-attribute

agent = Agent(policy_net=policy_net, optimizer=optimizer)

gamma instance-attribute

gamma = gamma

solved_threshold instance-attribute

solved_threshold = solved_threshold

train

train() -> None
Source code in toyrl/sarsa.py
def train(self) -> None:
    epsilon = 1.0
    global_step = 0

    observation, _ = self.env.reset()
    while global_step < self.config.max_training_steps:
        global_step += 1
        epsilon = max(0.05, epsilon * 0.9999)

        action = self.agent.act(observation, epsilon)
        next_observation, reward, terminated, truncated, info = self.env.step(action)
        experience = Experience(
            observation=observation,
            action=action,
            reward=float(reward),
            next_observation=next_observation,
            terminated=terminated,
            truncated=truncated,
        )
        self.agent.add_experience(experience)
        observation = next_observation
        if self.env.render_mode is not None:
            self.env.render()

        if terminated or truncated:
            if info and "episode" in info:
                episode_reward = info["episode"]["r"]
                loss = self.agent.policy_update(gamma=self.gamma)
                self.agent.onpolicy_reset()
                print(
                    f"global_step={global_step}, epsilon={epsilon}, episodic_return={episode_reward}, loss={loss}"
                )
                if self.config.log_wandb:
                    wandb.log(
                        {
                            "global_step": global_step,
                            "episode_reward": episode_reward,
                            "loss": loss,
                        }
                    )
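
To run training end to end, a sketch along these lines should work; note that default_config enables Weights & Biases logging, so either log in to wandb first or pass a config with log_wandb=False:

from toyrl.sarsa import SarsaConfig, SarsaTrainer

config = SarsaConfig(log_wandb=False)  # or reuse toyrl.sarsa.default_config
trainer = SarsaTrainer(config=config)
trainer.train()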

  1. L. Graesser and W. L. Keng, Foundations of Deep Reinforcement Learning: Theory and Practice in Python. Addison-Wesley Professional, 2019.