Skip to content

envs.drone_race

lsy_drone_racing.envs.drone_race

Single drone racing environments.

Classes

DroneRaceEnv(freq, sim_config, track, sensor_range=0.5, control_mode='state', disturbances=None, randomizations=None, seed=None, max_episode_steps=1500, device='cpu')

Bases: RaceCoreEnv, Env

Single-agent drone racing environment.

Initialize the single-agent drone racing environment.

Parameters:

Name Type Description Default
freq int

Environment step frequency.

required
sim_config ConfigDict

Simulation configuration.

required
track ConfigDict

Track configuration.

required
sensor_range float

Sensor range.

0.5
control_mode Literal['state', 'attitude']

Control mode for the drones. See build_action_space for details.

'state'
disturbances ConfigDict | None

Disturbance configuration.

None
randomizations ConfigDict | None

Randomization configuration.

None
seed int | None

None / -1 for a generated seed or the random seed directly.

None
max_episode_steps int

Maximum number of steps per episode.

1500
device Literal['cpu', 'gpu']

Device used for the environment and the simulation.

'cpu'
Source code in lsy_drone_racing/envs/drone_race.py
def __init__(
    self,
    freq: int,
    sim_config: ConfigDict,
    track: ConfigDict,
    sensor_range: float = 0.5,
    control_mode: Literal["state", "attitude"] = "state",
    disturbances: ConfigDict | None = None,
    randomizations: ConfigDict | None = None,
    seed: int | None = None,
    max_episode_steps: int = 1500,
    device: Literal["cpu", "gpu"] = "cpu",
):
    """Set up a racing environment with exactly one world and one drone.

    Args:
        freq: Environment step frequency.
        sim_config: Simulation configuration.
        track: Track configuration.
        sensor_range: Sensor range.
        control_mode: Control mode for the drones. See `build_action_space` for details.
        disturbances: Disturbance configuration.
        randomizations: Randomization configuration.
        seed: None / -1 for a generated seed or the random seed directly.
        max_episode_steps: Maximum number of steps per episode.
        device: Device used for the environment and the simulation.
    """
    # The core environment is always batched; this wrapper pins both batch sizes to one.
    super().__init__(
        n_envs=1,
        n_drones=1,
        freq=freq,
        max_episode_steps=max_episode_steps,
        sim_config=sim_config,
        track=track,
        sensor_range=sensor_range,
        control_mode=control_mode,
        disturbances=disturbances,
        randomizations=randomizations,
        seed=seed,
        device=device,
    )
    self.action_space = build_action_space(control_mode, sim_config.drone_model)
    gate_count = len(track.gates)
    obstacle_count = len(track.obstacles)
    self.observation_space = build_observation_space(gate_count, obstacle_count)
    # The core env is created with autoreset enabled. Switch it off and rebuild the step
    # function afterwards so that the rebuilt function picks up the changed setting.
    self.settings = self.settings.replace(autoreset=False)
    self._step = self.build_step_fn()
Methods:
reset(seed=None, options=None)

Reset the environment.

Parameters:

Name Type Description Default
seed int | None

Random seed.

None
options dict | None

Additional reset options. Not used.

None

Returns:

Type Description
tuple[dict, dict]

The initial observation and info.

Source code in lsy_drone_racing/envs/drone_race.py
def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[dict, dict]:
    """Reset the environment and return the unbatched initial observation.

    Args:
        seed: Random seed.
        options: Additional reset options. Not used.

    Returns:
        The initial observation and info.
    """
    Env.reset(self, seed=seed, options=options)
    self.data, reset_out = self._reset(self.data, seed=seed)
    batched_obs, batched_info = reset_out
    # Index [0, 0] drops the two leading batch axes (presumably world and drone) so that
    # callers receive plain per-drone values.
    single_obs = {name: value[0, 0] for name, value in batched_obs.items()}
    single_info = {name: value[0, 0] for name, value in batched_info.items()}
    return single_obs, single_info
step(action)

Step the environment.

Parameters:

Name Type Description Default
action Array

Action for the drone.

required

Returns:

Type Description
tuple[dict, float, bool, bool, dict]

Observation, reward, terminated, truncated, and info.

Source code in lsy_drone_racing/envs/drone_race.py
def step(self, action: Array) -> tuple[dict, float, bool, bool, dict]:
    """Advance the environment by one step and return unbatched results.

    Args:
        action: Action for the drone.

    Returns:
        Observation, reward, terminated, truncated, and info.
    """
    self.data, step_out = self._step(self.data, action)
    batched_obs, batched_reward, batched_terminated, batched_truncated, batched_info = step_out
    # Index [0, 0] drops the two leading batch axes (presumably world and drone).
    single_obs = {name: value[0, 0] for name, value in batched_obs.items()}
    single_info = {name: value[0, 0] for name, value in batched_info.items()}
    # Convert the scalar array entries into native Python types.
    reward = float(batched_reward[0, 0])
    terminated = bool(batched_terminated[0, 0])
    truncated = bool(batched_truncated[0, 0])
    return single_obs, reward, terminated, truncated, single_info

VecDroneRaceEnv(num_envs, freq, sim_config, track, sensor_range=0.5, control_mode='state', disturbances=None, randomizations=None, seed=1337, max_episode_steps=1500, device='cpu')

Bases: RaceCoreEnv, VectorEnv

Vectorized single-agent drone racing environment.

Initialize the vectorized single-agent drone racing environment.

Parameters:

Name Type Description Default
num_envs int

Number of worlds in the vectorized environment.

required
freq int

Environment step frequency.

required
sim_config ConfigDict

Simulation configuration.

required
track ConfigDict

Track configuration.

required
sensor_range float

Sensor range.

0.5
control_mode Literal['state', 'attitude']

Control mode for the drones. See build_action_space for details.

'state'
disturbances ConfigDict | None

Disturbance configuration.

None
randomizations ConfigDict | None

Randomization configuration.

None
seed int

Random seed.

1337
max_episode_steps int

Maximum number of steps per episode.

1500
device Literal['cpu', 'gpu']

Device used for the environment and the simulation.

'cpu'
Source code in lsy_drone_racing/envs/drone_race.py
def __init__(
    self,
    num_envs: int,
    freq: int,
    sim_config: ConfigDict,
    track: ConfigDict,
    sensor_range: float = 0.5,
    control_mode: Literal["state", "attitude"] = "state",
    disturbances: ConfigDict | None = None,
    randomizations: ConfigDict | None = None,
    seed: int = 1337,
    max_episode_steps: int = 1500,
    device: Literal["cpu", "gpu"] = "cpu",
):
    """Set up `num_envs` parallel worlds that each contain a single drone.

    Args:
        num_envs: Number of worlds in the vectorized environment.
        freq: Environment step frequency.
        sim_config: Simulation configuration.
        track: Track configuration.
        sensor_range: Sensor range.
        control_mode: Control mode for the drones. See `build_action_space` for details.
        disturbances: Disturbance configuration.
        randomizations: Randomization configuration.
        seed: Random seed.
        max_episode_steps: Maximum number of steps per episode.
        device: Device used for the environment and the simulation.
    """
    super().__init__(
        n_envs=num_envs,
        n_drones=1,
        freq=freq,
        max_episode_steps=max_episode_steps,
        sim_config=sim_config,
        track=track,
        sensor_range=sensor_range,
        control_mode=control_mode,
        disturbances=disturbances,
        randomizations=randomizations,
        seed=seed,
        device=device,
    )
    self.num_envs = num_envs
    # Per-world spaces first, then the same spaces batched over all worlds.
    per_world_actions = build_action_space(control_mode, sim_config.drone_model)
    self.single_action_space = per_world_actions
    self.action_space = batch_space(per_world_actions, num_envs)
    gate_count = len(track.gates)
    obstacle_count = len(track.obstacles)
    per_world_obs = build_observation_space(gate_count, obstacle_count)
    self.single_observation_space = per_world_obs
    self.observation_space = batch_space(per_world_obs, num_envs)
Methods:
reset(seed=None, options=None)

Reset the environment in all worlds.

Parameters:

Name Type Description Default
seed int | None

Random seed.

None
options dict | None

Additional reset options. Not used.

None

Returns:

Type Description
tuple[dict, dict]

The initial observation and info.

Source code in lsy_drone_racing/envs/drone_race.py
def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[dict, dict]:
    """Reset every world and return the per-world initial observations.

    Args:
        seed: Random seed.
        options: Additional reset options. Not used.

    Returns:
        The initial observation and info.
    """
    VectorEnv.reset(self, seed=seed, options=options)
    self.data, reset_out = self._reset(self.data, seed=seed)
    batched_obs, batched_info = reset_out
    # Keep the leading axis and select index 0 on the second one (presumably the drone axis,
    # which has a single entry in this environment).
    world_obs = {name: value[:, 0] for name, value in batched_obs.items()}
    world_info = {name: value[:, 0] for name, value in batched_info.items()}
    return world_obs, world_info
step(action)

Step the environment in all worlds.

Parameters:

Name Type Description Default
action Array

Action for all worlds, i.e., a batch of (n_envs, action_dim) arrays.

required

Returns:

Type Description
tuple[dict, Array, Array, Array, dict]

Observation, reward, terminated, truncated, and info.

Source code in lsy_drone_racing/envs/drone_race.py
def step(self, action: Array) -> tuple[dict, Array, Array, Array, dict]:
    """Advance all worlds by one step.

    Args:
        action: Action for all worlds, i.e., a batch of (n_envs, action_dim) arrays.

    Returns:
        Observation, reward, terminated, truncated, and info.
    """
    self.data, step_out = self._step(self.data, action)
    batched_obs, batched_reward, batched_terminated, batched_truncated, batched_info = step_out
    # Keep the leading axis and select index 0 on the second one (presumably the drone axis).
    world_obs = {name: value[:, 0] for name, value in batched_obs.items()}
    world_info = {name: value[:, 0] for name, value in batched_info.items()}
    reward = batched_reward[:, 0]
    terminated = batched_terminated[:, 0]
    truncated = batched_truncated[:, 0]
    return world_obs, reward, terminated, truncated, world_info

Functions:

lsy_drone_racing.envs.multi_drone_race

Multi-agent drone racing environments.

Classes

MultiDroneRaceEnv(freq, sim_config, track, sensor_range=0.5, control_mode='state', disturbances=None, randomizations=None, seed=None, max_episode_steps=1500, device='cpu')

Bases: RaceCoreEnv, Env

Multi-agent drone racing environment.

This environment enables multiple agents to simultaneously compete with each other on the same track.

Initialize the multi-agent drone racing environment.

Parameters:

Name Type Description Default
freq int

Environment step frequency.

required
sim_config ConfigDict

Simulation configuration.

required
track ConfigDict

Track configuration.

required
sensor_range float

Sensor range.

0.5
control_mode Literal['state', 'attitude']

Control mode for the drones. See build_action_space for details.

'state'
disturbances ConfigDict | None

Disturbance configuration.

None
randomizations ConfigDict | None

Randomization configuration.

None
seed int | None

None / -1 for a generated seed or the random seed directly.

None
max_episode_steps int

Maximum number of steps per episode.

1500
device Literal['cpu', 'gpu']

Device used for the environment and the simulation.

'cpu'
Source code in lsy_drone_racing/envs/multi_drone_race.py
def __init__(
    self,
    freq: int,
    sim_config: ConfigDict,
    track: ConfigDict,
    sensor_range: float = 0.5,
    control_mode: Literal["state", "attitude"] = "state",
    disturbances: ConfigDict | None = None,
    randomizations: ConfigDict | None = None,
    seed: int | None = None,
    max_episode_steps: int = 1500,
    device: Literal["cpu", "gpu"] = "cpu",
):
    """Set up a single world in which all drones listed in the track race together.

    Args:
        freq: Environment step frequency.
        sim_config: Simulation configuration.
        track: Track configuration.
        sensor_range: Sensor range.
        control_mode: Control mode for the drones. See `build_action_space` for details.
        disturbances: Disturbance configuration.
        randomizations: Randomization configuration.
        seed: None / -1 for a generated seed or the random seed directly.
        max_episode_steps: Maximum number of steps per episode.
        device: Device used for the environment and the simulation.
    """
    # The number of drones is taken from the track configuration.
    gate_count = len(track.gates)
    obstacle_count = len(track.obstacles)
    drone_count = len(track.drones)
    super().__init__(
        n_envs=1,
        n_drones=drone_count,
        freq=freq,
        max_episode_steps=max_episode_steps,
        sim_config=sim_config,
        track=track,
        sensor_range=sensor_range,
        control_mode=control_mode,
        disturbances=disturbances,
        randomizations=randomizations,
        seed=seed,
        device=device,
    )
    # Spaces are those of a single drone, batched over the number of drones.
    per_drone_actions = build_action_space(control_mode, sim_config.drone_model)
    self.action_space = batch_space(per_drone_actions, drone_count)
    per_drone_obs = build_observation_space(gate_count, obstacle_count)
    self.observation_space = batch_space(per_drone_obs, drone_count)
    # The core env is created with autoreset enabled. Switch it off and rebuild the step
    # function afterwards so that the rebuilt function picks up the changed setting.
    self.settings = self.settings.replace(autoreset=False)
    self._step = self.build_step_fn()
Methods:
reset(seed=None, options=None)

Reset the environment for all drones.

Parameters:

Name Type Description Default
seed int | None

Random seed.

None
options dict | None

Additional reset options. Not used.

None

Returns:

Type Description
tuple[dict, dict]

Observation and info for all drones.

Source code in lsy_drone_racing/envs/multi_drone_race.py
def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[dict, dict]:
    """Reset the environment and return the initial observations of all drones.

    Args:
        seed: Random seed.
        options: Additional reset options. Not used.

    Returns:
        Observation and info for all drones.
    """
    Env.reset(self, seed=seed, options=options)
    self.data, reset_out = self._reset(self.data, seed=seed)
    batched_obs, batched_info = reset_out
    # Index [0] drops the leading axis (presumably the world axis, of size one here) and keeps
    # the remaining per-drone axes.
    drone_obs = {name: value[0] for name, value in batched_obs.items()}
    drone_info = {name: value[0] for name, value in batched_info.items()}
    return drone_obs, drone_info
step(action)

Step the environment for all drones.

Parameters:

Name Type Description Default
action Array

Action for all drones, i.e., a batch of (n_drones, action_dim) arrays.

required

Returns:

Type Description
tuple[dict, Array, Array, Array, dict]

Observation, reward, terminated, truncated, and info for all drones.

Source code in lsy_drone_racing/envs/multi_drone_race.py
def step(self, action: Array) -> tuple[dict, Array, Array, Array, dict]:
    """Advance the environment by one step for all drones.

    Args:
        action: Action for all drones, i.e., a batch of (n_drones, action_dim) arrays.

    Returns:
        Observation, reward, terminated, truncated, and info for all drones. Note that the
        reward is the single entry at index [0, 0], and that the terminated and truncated flags
        are reduced with `all()` over the first row.
    """
    self.data, step_out = self._step(self.data, action)
    batched_obs, batched_reward, batched_terminated, batched_truncated, batched_info = step_out
    # Index [0] drops the leading axis (presumably the world axis, of size one here).
    drone_obs = {name: value[0] for name, value in batched_obs.items()}
    drone_info = {name: value[0] for name, value in batched_info.items()}
    # TODO: Fix by moving towards pettingzoo API
    # https://pettingzoo.farama.org/api/parallel/
    reward = batched_reward[0, 0]
    terminated = batched_terminated[0].all()
    truncated = batched_truncated[0].all()
    return drone_obs, reward, terminated, truncated, drone_info

VecMultiDroneRaceEnv(num_envs, freq, sim_config, track, sensor_range=0.5, control_mode='state', disturbances=None, randomizations=None, seed=1337, max_episode_steps=1500, device='cpu')

Bases: RaceCoreEnv, VectorEnv

Vectorized multi-agent drone racing environment.

This environment enables vectorized training of multi-agent drone racing agents.

Initialize the vectorized multi-agent drone racing environment.

Parameters:

Name Type Description Default
num_envs int

Number of worlds in the vectorized environment.

required
freq int

Environment step frequency.

required
sim_config ConfigDict

Simulation configuration.

required
track ConfigDict

Track configuration.

required
sensor_range float

Sensor range.

0.5
control_mode Literal['state', 'attitude']

Control mode for the drones. See build_action_space for details.

'state'
disturbances ConfigDict | None

Disturbance configuration.

None
randomizations ConfigDict | None

Randomization configuration.

None
seed int

Random seed.

1337
max_episode_steps int

Maximum number of steps per episode.

1500
device Literal['cpu', 'gpu']

Device used for the environment and the simulation.

'cpu'
Source code in lsy_drone_racing/envs/multi_drone_race.py
def __init__(
    self,
    num_envs: int,
    freq: int,
    sim_config: ConfigDict,
    track: ConfigDict,
    sensor_range: float = 0.5,
    control_mode: Literal["state", "attitude"] = "state",
    disturbances: ConfigDict | None = None,
    randomizations: ConfigDict | None = None,
    seed: int = 1337,
    max_episode_steps: int = 1500,
    device: Literal["cpu", "gpu"] = "cpu",
):
    """Initialize the vectorized multi-agent drone racing environment.

    Args:
        num_envs: Number of worlds in the vectorized environment.
        freq: Environment step frequency.
        sim_config: Simulation configuration.
        track: Track configuration.
        sensor_range: Sensor range.
        control_mode: Control mode for the drones. See `build_action_space` for details.
        disturbances: Disturbance configuration.
        randomizations: Randomization configuration.
        seed: Random seed.
        max_episode_steps: Maximum number of steps per episode.
        device: Device used for the environment and the simulation.
    """
    # The number of drones is taken from the track configuration.
    n_gates, n_obstacles, n_drones = len(track.gates), len(track.obstacles), len(track.drones)
    super().__init__(
        n_envs=num_envs,
        n_drones=n_drones,
        freq=freq,
        sim_config=sim_config,
        sensor_range=sensor_range,
        track=track,
        control_mode=control_mode,
        disturbances=disturbances,
        randomizations=randomizations,
        seed=seed,
        max_episode_steps=max_episode_steps,
        device=device,
    )
    self.num_envs = num_envs
    # The "single" spaces describe one world, i.e. one drone's space batched over all drones.
    self.single_action_space = batch_space(
        build_action_space(control_mode, sim_config.drone_model), n_drones
    )
    # Batch the per-world space once over the worlds. A nested `batch_space` call with its
    # default `n=1` would insert an extra singleton axis that the observation space below
    # does not have.
    self.action_space = batch_space(self.single_action_space, num_envs)
    self.single_observation_space = batch_space(
        build_observation_space(n_gates, n_obstacles), n_drones
    )
    self.observation_space = batch_space(self.single_observation_space, num_envs)
Methods:
reset(seed=None, options=None)

Reset the environment for all drones.

Parameters:

Name Type Description Default
seed int | None

Random seed.

None
options dict | None

Additional reset options. Not used.

None

Returns:

Type Description
tuple[dict, dict]

Observation and info for all drones.

Source code in lsy_drone_racing/envs/multi_drone_race.py
def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[dict, dict]:
    """Reset the environment and return the initial observations without unbatching.

    Args:
        seed: Random seed.
        options: Additional reset options. Not used.

    Returns:
        Observation and info for all drones.
    """
    VectorEnv.reset(self, seed=seed, options=options)
    self.data, reset_out = self._reset(self.data, seed=seed)
    # The results are handed back exactly as produced by the reset function.
    initial_obs, initial_info = reset_out
    return initial_obs, initial_info
step(action)

Step the environment for all drones.

Parameters:

Name Type Description Default
action Array

Action for all drones in all worlds, i.e., a batch of (n_envs, n_drones, action_dim) arrays.

required
Source code in lsy_drone_racing/envs/multi_drone_race.py
def step(self, action: Array) -> tuple[dict, Array, Array, Array, dict]:
    """Advance the environment by one step for all drones.

    Args:
        action: Action for all drones, i.e., a batch of (n_drones, action_dim) arrays.

    Returns:
        Observation, reward, terminated, truncated, and info, passed through unchanged from the
        step function.
    """
    self.data, step_out = self._step(self.data, action)
    # Unpack to make sure the step function produced exactly the five expected results.
    next_obs, step_reward, is_terminated, is_truncated, step_info = step_out
    return next_obs, step_reward, is_terminated, is_truncated, step_info

Functions:

lsy_drone_racing.envs.race_core

Core environment for drone racing simulations.

This module provides the shared logic for simulating drone racing environments. It defines a core environment class that wraps our drone simulation, drone control, gate tracking, and collision detection. The module serves as the base for both single-drone and multi-drone racing environments.

The environment is designed to be configurable, supporting:

  • Different control modes (state or attitude)
  • Customizable tracks with gates and obstacles
  • Various randomization options for robust policy training
  • Disturbance modeling for realistic flight conditions
  • Vectorized execution for parallel training

This module is primarily used as a base for the higher-level environments in drone_race and multi_drone_race, which provide Gymnasium-compatible interfaces for reinforcement learning, MPC and other control techniques.

Classes

EnvData

Struct holding the data of all auxiliary variables for the environment.

This dataclass stores the dynamic and static state of the environment that is not directly part of the physics simulation. It includes information about gate progress, drone status, and environment boundaries. Static variables are initialized once and do not change during the episode.

Attributes:

Name Type Description
target_gate Array

Current target gate index for each drone in each environment

gates_visited Array

Boolean flags indicating which gates have been visited by each drone

obstacles_visited Array

Boolean flags indicating which obstacles have been detected

last_drone_pos Array

Previous positions of drones, used for gate passing detection

marked_for_reset Array

Flags indicating which environments need to be reset

disabled_drones Array

Flags indicating which drones have crashed or are otherwise disabled

contact_masks Array

Masks for contact detection between drones and objects

pos_limit_low Array

Lower position limits for the environment

pos_limit_high Array

Upper position limits for the environment

gate_mj_ids Array

MuJoCo IDs for the gates

obstacle_mj_ids Array

MuJoCo IDs for the obstacles

max_episode_steps Array

Maximum number of steps per episode

sensor_range Array

Range at which drones can detect gates and obstacles

Methods:
create(n_gates, n_obstacles, contact_masks, max_episode_steps, sensor_range, pos_limit_low, pos_limit_high, nominal_gates_pos, nominal_gates_quat, nominal_obstacles_pos, sim_data, device) staticmethod

Create a new environment data struct with default values.

Source code in lsy_drone_racing/envs/race_core.py
@staticmethod
def create(
    n_gates: int,
    n_obstacles: int,
    contact_masks: Array,
    max_episode_steps: int,
    sensor_range: float,
    pos_limit_low: Array,
    pos_limit_high: Array,
    nominal_gates_pos: Array,
    nominal_gates_quat: Array,
    nominal_obstacles_pos: Array,
    sim_data: SimData,
    device: Device,
) -> EnvData:
    """Build an environment data struct with zeroed progress and nominal object poses.

    The number of worlds and drones is read from `sim_data.core`. All arrays are created on
    `device`.
    """
    n_envs = sim_data.core.n_worlds
    n_drones = sim_data.core.n_drones

    def tile_over_worlds(nominal: Array) -> Array:
        """Repeat `nominal` along a new leading axis of length `n_envs` on `device`."""
        # The (n_envs, 1, 1) repetition assumes `nominal` is two-dimensional -- TODO confirm.
        return jax.device_put(jp.tile(nominal[None, ...], (n_envs, 1, 1)), device)

    world_gates_pos = tile_over_worlds(nominal_gates_pos)
    world_gates_quat = tile_over_worlds(nominal_gates_quat)
    world_obstacles_pos = tile_over_worlds(nominal_obstacles_pos)

    drone_shape = (n_envs, n_drones)
    drone_vec3_shape = (n_envs, n_drones, 3)
    return EnvData(
        # Per-drone progress and status, all starting at zero / False
        target_gate=jp.zeros(drone_shape, dtype=int, device=device),
        gates_visited=jp.zeros((n_envs, n_drones, n_gates), dtype=bool, device=device),
        obstacles_visited=jp.zeros((n_envs, n_drones, n_obstacles), dtype=bool, device=device),
        last_drone_pos=jp.zeros(drone_vec3_shape, dtype=np.float32, device=device),
        marked_for_reset=jp.zeros(n_envs, dtype=bool, device=device),
        disabled_drones=jp.zeros(drone_shape, dtype=bool, device=device),
        contact_masks=jp.array(contact_masks, dtype=bool, device=device),
        steps=jp.zeros(n_envs, dtype=int, device=device),
        takeoff_pos=jp.zeros(drone_vec3_shape, dtype=np.float32, device=device),
        # Limits and scalar settings converted to device arrays
        pos_limit_low=jp.array(pos_limit_low, dtype=np.float32, device=device),
        pos_limit_high=jp.array(pos_limit_high, dtype=np.float32, device=device),
        max_episode_steps=jp.array([max_episode_steps], dtype=int, device=device),
        # The current object poses start out identical to the nominal ones
        gates_pos=world_gates_pos,
        gates_quat=world_gates_quat,
        obstacles_pos=world_obstacles_pos,
        nominal_gates_pos=world_gates_pos,
        nominal_gates_quat=world_gates_quat,
        nominal_obstacles_pos=world_obstacles_pos,
        sim_data=sim_data,
        sensor_range=jp.array([sensor_range], dtype=jp.float32, device=device),
    )

EnvSettings

Struct holding all configuration settings for the environment.

Methods:
create(freq, max_episode_steps, pos_limit_low, pos_limit_high, camera, cam_config, disturbances, randomizations, device, autoreset=True) staticmethod

Create a new environment settings struct from a configuration dictionary.

Source code in lsy_drone_racing/envs/race_core.py
@staticmethod
def create(
    freq: int,
    max_episode_steps: int,
    pos_limit_low: Array,
    pos_limit_high: Array,
    camera: int | str,
    cam_config: dict[str, int | list[float]],
    disturbances: dict[str, Callable[[Array, Array, Array], Array]],
    randomizations: dict[str, Callable[[Array, Array, Array], Array]],
    device: Device,
    autoreset: bool = True,
) -> EnvSettings:
    """Build an environment settings struct from the given values.

    The position limits are converted to float32 arrays on `device`. All other arguments are
    stored as given.
    """
    limit_low = jp.array(pos_limit_low, dtype=jp.float32, device=device)
    limit_high = jp.array(pos_limit_high, dtype=jp.float32, device=device)
    return EnvSettings(
        freq=freq,
        max_episode_steps=max_episode_steps,
        pos_limit_low=limit_low,
        pos_limit_high=limit_high,
        camera=camera,
        cam_config=cam_config,
        disturbances=disturbances,
        randomizations=randomizations,
        device=device,
        autoreset=autoreset,
    )

RaceCoreEnv(n_envs, n_drones, freq, sim_config, sensor_range, track, control_mode='state', disturbances=None, randomizations=None, seed=None, max_episode_steps=1500, device='cpu')

The core environment for drone racing simulations.

This environment simulates a drone racing scenario where one or more drones navigate through a series of gates in a predefined track. It supports various configuration options for randomization, disturbances, and physics models.

The environment provides:

  • A customizable track with gates and obstacles
  • Configurable simulation and control frequencies
  • Support for different physics models (e.g., identified dynamics, analytical dynamics)
  • Randomization of drone properties and initial conditions
  • Disturbance modeling for realistic flight conditions
  • Symbolic expressions for advanced control techniques (optional)

The environment tracks the drone's progress through the gates and provides termination conditions based on gate passages and collisions.

The observation space is a dictionary with the following keys:

  • pos: Drone position
  • quat: Drone orientation as a quaternion (x, y, z, w)
  • vel: Drone linear velocity
  • ang_vel: Drone angular velocity
  • gates_pos: Positions of the gates
  • gates_quat: Orientations of the gates
  • gates_visited: Flags indicating whether the drone is, or has already been, within sensor range of the gates, so that their true position is known
  • obstacles_pos: Positions of the obstacles
  • obstacles_visited: Flags indicating whether the drone is, or has already been, within sensor range of the obstacles, so that their true position is known
  • target_gate: The current target gate index

The action space consists of a desired full-state command [x, y, z, vx, vy, vz, ax, ay, az, yaw, rrate, prate, yrate] that is tracked by the drone's low-level controller, or a desired collective thrust and attitude command [collective thrust, roll, pitch, yaw].

Initialize the RaceCoreEnv.

Parameters:

Name Type Description Default
n_envs int

Number of worlds in the vectorized environment.

required
n_drones int

Number of drones.

required
freq int

Environment step frequency.

required
sim_config ConfigDict

Configuration dictionary for the simulation.

required
sensor_range float

Sensor range for gate and obstacle detection.

required
control_mode Literal['state', 'attitude']

Control mode for the drones. See build_action_space for details.

'state'
track ConfigDict

Track configuration.

required
disturbances ConfigDict | None

Disturbance configuration.

None
randomizations ConfigDict | None

Randomization configuration.

None
seed int | None

None / -1 for a generated seed or the random seed directly.

None
max_episode_steps int

Maximum number of steps per episode. Needs to be tracked manually for vectorized environments.

1500
device Literal['cpu', 'gpu']

Device used for the environment and the simulation.

'cpu'
Source code in lsy_drone_racing/envs/race_core.py
def __init__(
    self,
    n_envs: int,
    n_drones: int,
    freq: int,
    sim_config: ConfigDict,
    sensor_range: float,
    track: ConfigDict,
    control_mode: Literal["state", "attitude"] = "state",
    disturbances: ConfigDict | None = None,
    randomizations: ConfigDict | None = None,
    seed: int | None = None,
    max_episode_steps: int = 1500,
    device: Literal["cpu", "gpu"] = "cpu",
):
    """Initialize the core racing environment.

    Creates the simulation, loads the track into it, builds the environment settings and data
    structs, and finally generates the reset, step and render-sync functions.

    Args:
        n_envs: Number of worlds in the vectorized environment.
        n_drones: Number of drones.
        freq: Environment step frequency.
        sim_config: Configuration dictionary for the simulation.
        sensor_range: Sensor range for gate and obstacle detection.
        track: Track configuration.
        control_mode: Control mode for the drones. See `build_action_space` for details.
        disturbances: Disturbance configuration.
        randomizations: Randomization configuration.
        seed: None / -1 for a generated seed or the random seed directly.
        max_episode_steps: Maximum number of steps per episode. Needs to be tracked manually for
            vectorized environments.
        device: Device used for the environment and the simulation.

    Raises:
        ValueError: If the simulation frequency is not a multiple of `freq`.
    """
    super().__init__()
    # --- Argument checks ---
    if sim_config.freq % freq != 0:
        raise ValueError(f"({sim_config.freq=}) is no multiple of ({freq=})")
    assert seed is None or isinstance(seed, int), f"Unexpected seed type: {type(seed)}"

    # --- Seeding ---
    # TOML cannot express None, so -1 is accepted as a stand-in for "generate a seed".
    if seed == -1:
        seed = None
    # JAX needs an integer key. It is drawn from a numpy generator: unseeded (and therefore
    # random) if no seed was given, otherwise reproducible for the same seed.
    rng_key = int(np.random.default_rng(seed).integers(0, 2**32 - 1))

    # --- Simulation ---
    # The XML path is optional in the simulation config.
    xml_path = getattr(sim_config, "xml_path", None)
    xml_path = Path(xml_path) if xml_path else None
    self.sim = Sim(
        n_worlds=n_envs,
        n_drones=n_drones,
        physics=sim_config.physics,
        drone_model=sim_config.drone_model,
        control=control_mode,
        freq=sim_config.freq,
        state_freq=freq,
        attitude_freq=sim_config.attitude_freq,
        rng_key=rng_key,
        device=device,
        xml_path=xml_path,
    )
    self._load_track_into_sim(track)
    use_box_collision(self.sim, True)

    # --- Environment settings and data ---
    self.track = track
    gates, obstacles, drones = load_track(track)
    n_gates = len(track.gates)
    n_obstacles = len(track.obstacles)
    contact_masks = _load_contact_masks(self.sim)
    # Turn each disturbance / randomization spec into a function, keyed by its mode.
    disturbance_specs = {} if disturbances is None else disturbances
    disturbance_fns = {mode: rng_spec2fn(spec) for mode, spec in disturbance_specs.items()}
    randomization_specs = {} if randomizations is None else randomizations
    randomization_fns = {mode: rng_spec2fn(spec) for mode, spec in randomization_specs.items()}
    # The same position limits are used for the settings and the data struct.
    limit_low = [-3, -3, 0.0]
    limit_high = [3, 3, 2.5]
    self.settings = EnvSettings.create(
        freq=freq,
        max_episode_steps=max_episode_steps,
        pos_limit_low=limit_low,
        pos_limit_high=limit_high,
        camera=sim_config.camera,
        cam_config=sim_config.cam_config[0],
        disturbances=disturbance_fns,
        randomizations=randomization_fns,
        device=jax.devices(device)[0],
        autoreset=True,
    )
    self.data = EnvData.create(
        n_gates=n_gates,
        n_obstacles=n_obstacles,
        contact_masks=contact_masks,
        max_episode_steps=max_episode_steps,
        sensor_range=sensor_range,
        pos_limit_low=limit_low,
        pos_limit_high=limit_high,
        nominal_gates_pos=gates.nominal_pos,
        nominal_gates_quat=gates.nominal_quat,
        nominal_obstacles_pos=obstacles.nominal_pos,
        sim_data=self.sim.data,
        device=self.settings.device,
    )

    # --- Generated functions ---
    self._setup_sim(randomization_fns, drones)
    self._reset = self.build_reset_fn()
    self._step = self.build_step_fn()
    self._render_sync = self.build_render_sync_fn()
Attributes
drone_mass property

The mass of the drones in the environment.

mocap_ids property

The MuJoCo mocap IDs for the gates and obstacles.

Methods:
build_apply_action_fn()

Build a function that applies the action to the simulation.

Source code in lsy_drone_racing/envs/race_core.py
def build_apply_action_fn(self) -> Callable[[Array, EnvData], EnvData]:
    """Build a function that applies the action to the simulation.

    The control mode is read from ``self.sim.control`` when this builder runs and selects the
    controller function used by the returned closure. The action disturbance, if configured in
    ``self.settings.disturbances``, is captured at build time as well.

    Returns:
        A function ``apply_action(action, data)`` that returns the updated environment data.

    Raises:
        ValueError: If the control mode of the simulation is not supported.
    """
    action_space = build_action_space(self.sim.control, self.sim.drone_model)
    if self.sim.control == "state":
        ctrl_fn = F.state_control
    elif self.sim.control == "attitude":
        ctrl_fn = F.attitude_control
    else:
        raise ValueError(f"Unsupported control mode: {self.sim.control}")
    disturbances = self.settings.disturbances

    def apply_action(action: Array, data: EnvData) -> EnvData:
        """Clip, optionally disturb, and register the commanded action in the simulation."""
        # Bring the action into (n_worlds, n_drones, action_dim) layout
        action = action.reshape((data.sim_data.core.n_worlds, data.sim_data.core.n_drones, -1))
        action = jp.clip(action, action_space.low, action_space.high)
        if "action" in disturbances:
            # Consume a subkey for the noise and store the advanced key back into the sim data
            key, subkey = jax.random.split(data.sim_data.core.rng_key)
            # The disturbance is added after clipping, so the result may leave the action bounds
            action += disturbances["action"](subkey, action.shape)
            sim_data = data.sim_data.replace(core=data.sim_data.core.replace(rng_key=key))
            data = data.replace(sim_data=sim_data)
        return data.replace(sim_data=ctrl_fn(data.sim_data, action))

    return apply_action
build_contact_check_fn()

Build a function that checks for contacts between drones and gates/obstacles.

Note

Passing the full mjx_data into jit-compiled functions is expensive because the tree contains many elements and is flattened before the jit boundary. To avoid this cost, we fuse mjx_data into the contact_check function and only sync the gate and obstacle poses inside the function. This way, we only need to pass EnvData, which is faster to canonicalize.

Source code in lsy_drone_racing/envs/race_core.py
def build_contact_check_fn(self) -> Callable[[EnvData], Array]:
    """Build a function that checks for contacts between drones and gates/obstacles.

    Note:
        Passing the full mjx_data into jit-compiled functions is expensive because the tree
        contains many elements and is flattened **before** the jit boundary. To avoid this cost,
        the mjx_data is captured by the returned closure and only the gate and obstacle poses
        are written into it **inside** the function. The closure therefore only needs EnvData
        as its argument, which is faster to canonicalize.

    Returns:
        A function mapping the environment data to a boolean contact array.
    """
    masks = _load_contact_masks(self.sim)
    gate_ids, obstacle_ids = self.mocap_ids
    base_mjx_data = self.sim.mjx_data

    def check_contacts(data: EnvData) -> Array:
        """Check for contacts between drones and gates/obstacles."""
        # Rolling by one moves the last quaternion component to the front.
        # NOTE(review): presumably converts scalar-last to MuJoCo's scalar-first order - confirm
        rolled_gates_quat = jp.roll(data.gates_quat, 1, axis=-1)
        # Write the current gate and obstacle poses into the captured mocap buffers
        pos = base_mjx_data.mocap_pos.at[..., gate_ids, :].set(data.gates_pos)
        pos = pos.at[..., obstacle_ids, :].set(data.obstacles_pos)
        quat = base_mjx_data.mocap_quat.at[..., gate_ids, :].set(rolled_gates_quat)
        posed_mjx_data = base_mjx_data.replace(mocap_pos=pos, mocap_quat=quat)
        # Sync changes to MuJoCo and perform a collision check
        synced = sync_sim2mjx(data.sim_data, posed_mjx_data, self.sim.mjx_model)
        penetrating = synced[1]._impl.contact.dist < 0
        # Reduce over the contact axis after applying the contact masks
        return jp.any(penetrating[:, None, :] & masks, axis=-1)

    return check_contacts
build_render_sync_fn()

Build a function that syncs the environment data with the MuJoCo data for rendering.

Source code in lsy_drone_racing/envs/race_core.py
def build_render_sync_fn(self) -> Callable[[EnvData, Data], tuple[EnvData, Data]]:
    """Build a function that syncs the environment data with the MuJoCo data for rendering.

    Returns:
        A jit-compiled function ``render_sync(data, mjx_data)`` that returns the environment
        data with the synced simulation data together with the updated MuJoCo data.
    """
    gate_ids, obstacle_ids = self.mocap_ids
    mjx_model = self.sim.mjx_model

    @jax.jit
    def render_sync(data: EnvData, mjx_data: Data) -> tuple[EnvData, Data]:
        """Sync the environment data with the MuJoCo data for rendering."""
        # Write the current gate and obstacle poses into the mocap buffers
        mocap_pos = mjx_data.mocap_pos.at[..., gate_ids, :].set(data.gates_pos)
        mocap_pos = mocap_pos.at[..., obstacle_ids, :].set(data.obstacles_pos)
        # Rolling by one moves the last quaternion component to the front
        mocap_quat = mjx_data.mocap_quat.at[..., gate_ids, :].set(
            jp.roll(data.gates_quat, 1, axis=-1)
        )
        mjx_data = mjx_data.replace(mocap_pos=mocap_pos, mocap_quat=mocap_quat)
        sim_data, mjx_data = sync_sim2mjx(data.sim_data, mjx_data, mjx_model)
        return data.replace(sim_data=sim_data), mjx_data

    return render_sync
build_reset_fn()

Build a function that resets the environment data and simulation data.

Source code in lsy_drone_racing/envs/race_core.py
def build_reset_fn(
    self,
) -> Callable[
    [EnvData, int | None, Array | None], tuple[EnvData, tuple[dict[str, Array], dict]]
]:
    """Build a function that resets the environment data and simulation data.

    Returns:
        A jit-compiled function ``reset(data, seed=None, mask=None)`` that returns the reset
        environment data and a tuple of the observation and an empty info dict.
    """
    sim_reset_fn = self.sim.build_reset_fn()
    default_sim_data = self.sim.default_data
    randomize_track = build_track_randomization_fn(
        self.settings.randomizations, track=self.track
    )

    @jax.jit
    def reset(
        data: EnvData, seed: int | None = None, mask: Array | None = None
    ) -> tuple[EnvData, tuple[dict[str, Array], dict]]:
        """Reset the simulation and track, then return the new data and observation."""
        sim_data = data.sim_data
        # Re-seed only when a seed was passed explicitly
        if seed is not None:
            sim_data = seed_sim(sim_data, seed, sim_data.core.device)
        # One key stays in the sim data, the other drives the track randomization
        next_key, track_key = jax.random.split(sim_data.core.rng_key, 2)
        advanced_core = sim_data.core.replace(rng_key=next_key)
        # Randomization of the drone is compiled into the sim reset pipeline, so it does not
        # have to be triggered explicitly here
        sim_data = sim_reset_fn(sim_data.replace(core=advanced_core), default_sim_data, mask)
        data = randomize_track(data.replace(sim_data=sim_data), mask, track_key)
        data = _reset_env_data(data, mask)
        info = {}
        return data, (obs(data), info)

    return reset
build_step_fn()

Build a function that steps the environment.

Source code in lsy_drone_racing/envs/race_core.py
def build_step_fn(
    self,
) -> Callable[
    [EnvData, Array], tuple[EnvData, tuple[dict[str, Array], Array, Array, Array, dict]]
]:
    """Build a function that steps the environment.

    The step function is assembled from the apply-action, contact-check and simulation step
    functions. It reuses ``self._reset`` for autoresets, so the reset function has to be built
    before this method is called.

    Returns:
        A jit-compiled function ``step(data, action)`` that returns the updated environment
        data and the tuple ``(obs, reward, terminated, truncated, info)``.
    """
    apply_action_fn = self.build_apply_action_fn()
    contact_check_fn = self.build_contact_check_fn()
    sim_step_fn = self.sim.build_step_fn()
    reset_fn = self._reset
    # Read all settings once so that the jitted closure does not need to capture ``self``
    autoreset = self.settings.autoreset
    max_episode_steps = self.settings.max_episode_steps
    env_freq = self.settings.freq

    @jax.jit
    def step(
        data: EnvData, action: Array
    ) -> tuple[EnvData, tuple[dict[str, Array], Array, Array, Array, dict]]:
        # 1) Save marked_for_reset before it is updated. Autoresets need to be based on the
        # previous flags, not the ones from the current step
        marked_for_reset = data.marked_for_reset
        # 2) Register the commanded action in the sim controllers
        data = apply_action_fn(action, data)
        # 3) Step the simulation for the number of sim steps per env step
        n_steps = data.sim_data.core.freq // env_freq
        sim_data = sim_step_fn(data.sim_data, n_steps)
        data = data.replace(sim_data=sim_data)
        # 4) Apply environment logic
        data = _update_disabled_drones(data, contact_check_fn(data))
        data = _warp_disabled_drones(data)  # Prevent interference with alive drones
        data = _update_visited_objects(data)
        data = _update_target_gates(data)
        data = _mark_drones_for_reset(data)
        data = data.replace(steps=data.steps + 1)
        # 5) Auto-reset envs if running with autoreset enabled. Disable for single-world envs
        if autoreset:
            # Only run the reset if at least one env is marked for reset. Both branches must
            # return the same structure, hence the no-op branch mirrors the reset output
            data, _ = jax.lax.cond(
                marked_for_reset.any(),
                reset_fn,
                lambda data, *_: (data, (obs(data), {})),
                data,
                None,
                marked_for_reset,
            )
        _truncated = truncated(data, max_episode_steps)
        return data, (obs(data), reward(data), terminated(data), _truncated, {})

    return step
close()

Close the environment by closing the underlying simulation.

Source code in lsy_drone_racing/envs/race_core.py
def close(self):
    """Close the environment by closing the underlying simulation."""
    self.sim.close()
render()

Render the environment.

Source code in lsy_drone_racing/envs/race_core.py
def render(self):
    """Render the environment.

    If the simulation data is not flagged as synced with the MuJoCo data, both are synced
    first and stored back on the environment and the simulation.
    """
    needs_sync = not self.data.sim_data.core.mjx_synced
    if needs_sync:
        synced = self._render_sync(self.data, self.sim.mjx_data)
        self.data, self.sim.mjx_data = synced
    settings = self.settings
    self.sim.render(camera=settings.camera, cam_config=settings.cam_config)

Functions:

build_action_space(control_mode, drone_model)

Create the action space for the environment.

Parameters:

Name Type Description Default
control_mode Literal['state', 'attitude']

The control mode to use. Either "state" for full-state control or "attitude" for attitude control.

required
drone_model str

Drone model of the environment.

required

Returns:

Type Description
Box

A Box space representing the action space for the specified control mode.

Source code in lsy_drone_racing/envs/race_core.py
def build_action_space(control_mode: Literal["state", "attitude"], drone_model: str) -> spaces.Box:
    """Create the action space for the environment.

    Args:
        control_mode: The control mode to use. Either "state" for full-state control
            or "attitude" for attitude control.
        drone_model: Drone model of the environment.

    Returns:
        A Box space representing the action space for the specified control mode.
    """
    if control_mode == "state":
        return spaces.Box(low=-np.inf, high=np.inf, shape=(13,))
    if control_mode == "attitude":
        params = ForceTorqueParams.load(drone_model)
        thrust_min, thrust_max = params.thrust_min * 4, params.thrust_max * 4
        return spaces.Box(
            np.array([-np.pi / 2, -np.pi / 2, -np.pi / 2, thrust_min], dtype=np.float32),
            np.array([np.pi / 2, np.pi / 2, np.pi / 2, thrust_max], dtype=np.float32),
        )
    raise ValueError(f"Invalid control mode: {control_mode}")

build_drone_reset_fn(randomizations)

Build the reset hook for the simulation.

Source code in lsy_drone_racing/envs/race_core.py
def build_drone_reset_fn(randomizations: dict) -> Callable[[SimData, Array], SimData]:
    """Build the reset hook for the simulation."""
    randomization_fns = ()
    for target, rng in sorted(randomizations.items()):
        match target:
            case "drone_pos":
                randomization_fns += (randomize_drone_pos_fn(rng),)
            case "drone_rpy":
                randomization_fns += (randomize_drone_quat_fn(rng),)
            case "drone_mass":
                randomization_fns += (randomize_drone_mass_fn(rng),)
            case "drone_inertia":
                randomization_fns += (randomize_drone_inertia_fn(rng),)
            case "gate_pos" | "gate_rpy" | "obstacle_pos":
                pass
            case _:
                raise ValueError(f"Invalid target: {target}")

    def reset_fn(data: SimData, mask: Array) -> SimData:
        for randomize_fn in randomization_fns:
            data = randomize_fn(data, mask)
        return data

    return reset_fn

build_dynamics_disturbance_fn(fn)

Build the dynamics disturbance function for the simulation.

Source code in lsy_drone_racing/envs/race_core.py
def build_dynamics_disturbance_fn(
    fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
) -> Callable[[SimData], SimData]:
    """Build the dynamics disturbance function for the simulation.

    Args:
        fn: Random function that is called with a PRNG key and the shape of the force array.

    Returns:
        A function that overwrites ``data.states.force`` with a fresh sample from ``fn`` and
        stores the advanced PRNG key in ``data.core``.
    """

    def dynamics_disturbance(data: SimData) -> SimData:
        """Sample a new force and advance the PRNG key of the simulation data."""
        next_key, sample_key = jax.random.split(data.core.rng_key)
        sampled_force = fn(sample_key, data.states.force.shape)  # World frame
        disturbed_states = data.states.replace(force=sampled_force)
        advanced_core = data.core.replace(rng_key=next_key)
        return data.replace(states=disturbed_states, core=advanced_core)

    return dynamics_disturbance

build_observation_space(n_gates, n_obstacles)

Create the observation space for the environment.

The observation space is a dictionary containing the drone state, gate information, and obstacle information.

Parameters:

Name Type Description Default
n_gates int

Number of gates in the environment.

required
n_obstacles int

Number of obstacles in the environment.

required
Source code in lsy_drone_racing/envs/race_core.py
def build_observation_space(n_gates: int, n_obstacles: int) -> spaces.Dict:
    """Create the observation space for the environment.

    The observation space is a dictionary containing the drone state, gate information,
    and obstacle information.

    Args:
        n_gates: Number of gates in the environment.
        n_obstacles: Number of obstacles in the environment.

    Returns:
        A Dict space with one entry per observation key.
    """

    def unbounded(*shape: int) -> spaces.Box:
        """Box without finite limits."""
        return spaces.Box(low=-np.inf, high=np.inf, shape=shape)

    def quaternion(*shape: int) -> spaces.Box:
        """Box limited to [-1, 1] for quaternion components."""
        return spaces.Box(low=-1, high=1, shape=shape)

    def flags(n: int) -> spaces.Box:
        """Boolean box with one entry per object."""
        return spaces.Box(low=0, high=1, shape=(n,), dtype=bool)

    # NOTE(review): Discrete(n_gates, start=-1) covers the values -1 .. n_gates - 2. Confirm that
    # target_gate never takes the value n_gates - 1, otherwise the space is one element short
    target_gate_space = spaces.Discrete(n_gates, start=-1)
    return spaces.Dict(
        {
            "pos": unbounded(3),
            "quat": quaternion(4),
            "vel": unbounded(3),
            "ang_vel": unbounded(3),
            "target_gate": target_gate_space,
            "gates_pos": unbounded(n_gates, 3),
            "gates_quat": quaternion(n_gates, 4),
            "gates_visited": flags(n_gates),
            "obstacles_pos": unbounded(n_obstacles, 3),
            "obstacles_visited": flags(n_obstacles),
        }
    )

build_track_randomization_fn(randomizations, track)

Build the track randomization function for the simulation.

Source code in lsy_drone_racing/envs/race_core.py
def build_track_randomization_fn(
    randomizations: dict, track: ConfigDict
) -> Callable[[EnvData, Array, jax.random.PRNGKey], EnvData]:
    """Build the track randomization function for the simulation."""
    randomization_fns = ()

    if track.randomize:
        random_layout_fn = build_full_track_randomization_fn(
            [gate["pos"][2] for gate in track.gates],
            [obstacle["pos"][2] for obstacle in track.obstacles],
            track.safety_limits.pos_limit_low,
            track.safety_limits.pos_limit_high,
        )
        randomization_fns += (random_layout_fn,)

    for target, rng in sorted(randomizations.items()):
        match target:
            case "gate_pos":
                randomization_fns += (randomize_gate_pos_fn(rng),)
            case "gate_rpy":
                randomization_fns += (randomize_gate_rpy_fn(rng),)
            case "obstacle_pos":
                randomization_fns += (randomize_obstacle_pos_fn(rng),)
            case "drone_pos" | "drone_rpy" | "drone_mass" | "drone_inertia":
                pass
            case _:
                raise ValueError(f"Invalid target: {target}")

    def track_randomization(data: EnvData, mask: Array, key: jax.random.PRNGKey) -> EnvData:
        # Reset to default track positions first
        data = leaf_replace(
            data,
            mask,
            gates_pos=data.gates_pos.at[...].set(data.nominal_gates_pos),
            gates_quat=data.gates_quat.at[...].set(data.nominal_gates_quat),
            obstacles_pos=data.obstacles_pos.at[...].set(data.nominal_obstacles_pos),
        )
        keys = jax.random.split(key, len(randomization_fns))
        for key, randomize_fn in zip(keys, randomization_fns, strict=True):
            data = randomize_fn(data, mask, key)
        return data

    return track_randomization

obs(data)

Return the observation of the environment.

Source code in lsy_drone_racing/envs/race_core.py
def obs(data: EnvData) -> dict[str, Array]:
    """Return the observation of the environment.

    Gate and obstacle poses are reported with their actual values where the corresponding
    visited flag is set, and with their nominal values otherwise.
    """
    states = data.sim_data.states
    # Trailing axis lets the flags broadcast over the pose components
    gates_seen = data.gates_visited[..., None]
    obstacles_seen = data.obstacles_visited[..., None]
    return {
        "pos": states.pos,
        "quat": states.quat,
        "vel": states.vel,
        "ang_vel": states.ang_vel,
        "target_gate": data.target_gate,
        "gates_pos": jp.where(
            gates_seen, data.gates_pos[:, None], data.nominal_gates_pos[:, None]
        ),
        "gates_quat": jp.where(
            gates_seen, data.gates_quat[:, None], data.nominal_gates_quat[:, None]
        ),
        "gates_visited": data.gates_visited,
        "obstacles_pos": jp.where(
            obstacles_seen, data.obstacles_pos[:, None], data.nominal_obstacles_pos[:, None]
        ),
        "obstacles_visited": data.obstacles_visited,
    }

reward(data)

Compute the reward for the current state.

Note

The current sparse reward function will most likely not work directly for training an agent. If you want to use reinforcement learning, you will need to define your own reward function.

Returns:

Type Description
Array

Reward for the current state.

Source code in lsy_drone_racing/envs/race_core.py
def reward(data: EnvData) -> Array:
    """Compute the reward for the current state.

    The reward is -1 for every entry whose target gate equals -1 and 0 otherwise.

    Note:
        The current sparse reward function will most likely not work directly for training an
        agent. If you want to use reinforcement learning, you will need to define your own
        reward function.

    Returns:
        Reward for the current state.
    """
    no_target_gate = data.target_gate == -1
    # Multiplying the boolean array by a float converts it implicitly
    return -1.0 * no_target_gate

rng_spec2fn(fn_spec)

Convert a function spec to a wrapped and scaled function from jax.random.

Source code in lsy_drone_racing/envs/race_core.py
def rng_spec2fn(fn_spec: dict) -> Callable:
    """Convert a function spec to a wrapped and scaled function from jax.random.

    Args:
        fn_spec: Spec with the name of the ``jax.random`` function under "fn" and the optional
            entries "offset" (default 0), "scale" (default 1) and "kwargs" (default empty).
            List-valued kwargs are converted to numpy arrays.

    Returns:
        A function that forwards its arguments to the selected ``jax.random`` function and
        returns ``sample * scale + offset``.

    Raises:
        KeyError: If "shape" is part of the kwargs, or if the spec has no "fn" entry.
    """
    offset = np.array(fn_spec.get("offset", 0))
    scale = np.array(fn_spec.get("scale", 1))
    spec_kwargs = fn_spec.get("kwargs", {})
    if "shape" in spec_kwargs:
        raise KeyError("Shape must not be specified for randomization functions.")
    bound_kwargs = {}
    for name, value in spec_kwargs.items():
        bound_kwargs[name] = np.array(value) if isinstance(value, list) else value
    sampler = partial(getattr(jax.random, fn_spec["fn"]), **bound_kwargs)

    def random_fn(*args: Any, **kwargs: Any) -> Array:
        """Draw a sample and apply the configured scale and offset."""
        sample = sampler(*args, **kwargs)
        return sample * scale + offset

    return random_fn

terminated(data)

Check if the episode is terminated, i.e., if all drones are disabled.

Source code in lsy_drone_racing/envs/race_core.py
def terminated(data: EnvData) -> Array:
    """Return the termination flags of the environment.

    The flags are taken directly from ``data.disabled_drones``. No reduction over drones is
    applied here.
    """
    return data.disabled_drones

truncated(data, max_episode_steps)

Array of booleans indicating if the episode is truncated.

Source code in lsy_drone_racing/envs/race_core.py
def truncated(data: EnvData, max_episode_steps: int) -> Array:
    """Array of booleans indicating if the episode is truncated.

    Args:
        data: Environment data providing the step counter and the number of drones.
        max_episode_steps: Number of steps at which an episode counts as truncated.

    Returns:
        The flag ``steps >= max_episode_steps`` repeated ``n_drones`` times along a new
        trailing axis.
    """
    step_limit_reached = data.steps >= max_episode_steps
    per_drone_reps = (1, data.sim_data.core.n_drones)
    return jp.tile(step_limit_reached[..., None], per_drone_reps)