Examples
These examples build on each other — each one introduces one new concept on top of the previous. Start from the top if you're new, or jump to whichever section covers what you need.
Hover
A single drone commanded to hold a fixed position (x = y = z = 0.1) using state control. This is the minimal end-to-end loop: create a Sim, reset it, apply a state command, and step forward.
import numpy as np
from crazyflow.control import Control
from crazyflow.sim import Physics, Sim
def main():
    """Hold a single drone at a fixed setpoint with the state controller and render it."""
    sim = Sim(
        n_worlds=1,
        n_drones=1,
        physics=Physics.first_principles,
        control=Control.state,
        freq=500,
        attitude_freq=500,
        state_freq=100,
        device="cpu",
    )
    sim.reset()

    duration, fps = 5.0, 60
    # State command layout:
    # [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate]
    setpoint = np.zeros((sim.n_worlds, sim.n_drones, 13))
    setpoint[..., 0:3] = 0.1  # target x, y and z

    physics_steps_per_cmd = sim.freq // sim.control_freq
    n_cmds = int(duration * sim.control_freq)
    for step_idx in range(n_cmds):
        sim.state_control(setpoint)
        sim.step(physics_steps_per_cmd)
        # Render only about `fps` times per simulated second instead of on every command.
        if (step_idx * fps) % sim.control_freq < fps:
            sim.render()
    sim.close()


if __name__ == "__main__":
    main()
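Attitude control
Commanding roll, pitch, yaw, and collective thrust directly. This level bypasses the Mellinger position loop and is typical for RL agents that output attitude targets. In this example, a small PD position controller written in NumPy computes those targets to fly an ascending circle.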
import os
import numpy as np
os.environ["SCIPY_ARRAY_API"] = "1"
from scipy.spatial.transform import Rotation as R
from crazyflow.control import Control
from crazyflow.sim import Sim
# Per-axis (x, y, z) gains of the PD position controller in `control`.
kp = np.array([0.4, 0.4, 1.25])
# Integral gains. NOTE(review): currently unused; `control` implements PD only.
ki = np.array([0.05, 0.05, 0.05])
kd = np.array([0.2, 0.2, 0.4])
g = 9.81  # gravitational acceleration (m/s^2)


def control(
    t: float, obs: dict[str, np.ndarray], pos_start: np.ndarray, drone_mass: float
) -> np.ndarray:
    """Compute an attitude command that tracks an ascending, yawing circle.

    The reference is a unit circle in the xy-plane that passes through ``pos_start`` at t = 0,
    climbs at 0.2 m/s and yaws at 1 rad/s. A PD law on position and velocity error plus gravity
    compensation gives a desired thrust vector. That vector is turned into a desired orientation
    (body z-axis along the thrust, body x-axis as close as possible to the yaw heading) and a
    collective thrust.

    Args:
        t: Time since the start of the trajectory in seconds.
        obs: Current state with keys ``"pos"``, ``"vel"`` and ``"quat"``. The quaternion is passed
            to scipy's ``Rotation.from_quat`` (scalar-last convention).
        pos_start: Initial position; only its x and y components are used.
        drone_mass: Drone mass used for gravity compensation.

    Returns:
        float32 array ``[roll, pitch, yaw, thrust]`` with angles in radians (``"xyz"`` Euler).
    """
    # Reference trajectory
    circle_offset = np.array([np.cos(t) - 1, np.sin(t)])
    ref_pos = np.zeros(3)
    ref_pos[:2] = pos_start[:2] + circle_offset
    ref_pos[2] = 0.2 * t
    ref_vel = np.zeros_like(ref_pos)
    ref_yaw = t

    # PD feedback on the tracking error, plus a feed-forward term that cancels gravity.
    err_pos = ref_pos - np.array(obs["pos"])
    err_vel = ref_vel - np.array(obs["vel"])
    thrust_vec = np.zeros(3) + kp * err_pos + kd * err_vel
    thrust_vec[2] += drone_mass * g

    # Collective thrust: projection of the desired thrust vector onto the current body z-axis.
    body_z = R.from_quat(obs["quat"]).as_matrix()[:, 2]
    collective_thrust = thrust_vec.dot(body_z)

    # Desired orientation: z along the thrust vector, x aligned with the yaw heading.
    z_des = thrust_vec / np.linalg.norm(thrust_vec)
    heading = np.array([np.cos(ref_yaw), np.sin(ref_yaw), 0.0])
    y_des = np.cross(z_des, heading)
    y_des = y_des / np.linalg.norm(y_des)
    x_des = np.cross(y_des, z_des)
    rot_des = np.array([x_des, y_des, z_des]).T  # columns are the desired body axes

    roll_pitch_yaw = R.from_matrix(rot_des).as_euler("xyz", degrees=False)
    return np.concatenate([roll_pitch_yaw, [collective_thrust]], dtype=np.float32)
def main():
    """Fly an ascending circle by streaming attitude commands computed by `control`."""
    sim = Sim(control=Control.attitude)
    sim.reset()
    duration, fps = 6.5, 60
    att_cmd = np.zeros((sim.n_worlds, sim.n_drones, 4))  # [roll, pitch, yaw, thrust]
    # Captured once before flying; the circle in `control` is anchored at this position.
    start_xyz = sim.data.states.pos[0, 0]

    steps_per_cmd = sim.freq // sim.control_freq
    for k in range(int(duration * sim.control_freq)):
        states = sim.data.states
        obs = {"pos": states.pos[0, 0], "vel": states.vel[0, 0], "quat": states.quat[0, 0]}
        t_now = k / sim.control_freq
        att_cmd[0, 0, :] = control(t_now, obs, start_xyz, sim.data.params.mass[0, 0, 0])
        sim.attitude_control(att_cmd)
        sim.step(steps_per_cmd)
        # Throttle rendering to roughly `fps` frames per simulated second.
        if (k * fps) % sim.control_freq < fps:
            sim.render()
    sim.close()


if __name__ == "__main__":
    main()
Gradient descent through dynamics
Because the simulator is built entirely from JAX operations, jax.grad can differentiate through it. The floor-clipping stage at the end of the step pipeline would kill the gradients, so it is removed before the step function is built; gradients then flow freely through the entire trajectory. Plain gradient descent then tunes a constant attitude command so the drone approaches a height of 1 m.
import time
import jax
import jax.numpy as jnp
from numpy.typing import NDArray
from crazyflow.control import Control
from crazyflow.sim import Physics, Sim
from crazyflow.sim.data import SimData
def main():
    """Optimize a constant attitude command by gradient descent through the simulator.

    The cost is the squared distance between the drone's height after 10 simulation steps and a
    1 m target. ``jax.grad`` differentiates that cost with respect to the attitude command, and a
    few iterations of plain gradient descent update the command. Timing and results are printed.
    """
    n_iters = 10  # gradient descent iterations that are timed
    lr = 0.1  # gradient descent step size

    sim = Sim(control=Control.attitude, physics=Physics.first_principles, attitude_freq=50)
    # Remove the last pipeline stage (floor clipping), which kills gradients.
    sim.step_pipeline = sim.step_pipeline[:-1]
    sim_step = sim.build_step_fn()

    def step(cmd: NDArray, data: SimData) -> jax.Array:
        """Stage `cmd` as the attitude command, simulate 10 steps and return the height cost."""
        data = data.replace(
            controls=data.controls.replace(attitude=data.controls.attitude.replace(staged_cmd=cmd))
        )
        data = sim_step(data, 10)
        return (data.states.pos[0, 0, 2] - 1.0) ** 2  # Quadratic cost to reach 1m height

    step_grad = jax.jit(jax.grad(step))

    # Attitude command layout: [roll, pitch, yaw, thrust]. Start level, with 5% more thrust than
    # the drone's weight.
    cmd = jnp.zeros((1, 1, 4), dtype=jnp.float32)
    cmd = cmd.at[..., 3].set(sim.data.params.mass[0, 0, 0] * 9.81 * 1.05)
    # Trigger jax's jit to compile the gradient function. This is not necessary, but it ensures that
    # the timings are not affected by the compilation time.
    step_grad(cmd, sim.data).block_until_ready()
    # JAX compiles again if static properties change. NOTE(review): the cause of this second
    # compilation is unclear; running one update here forces all recompilations before timing.
    step_grad(cmd - lr * step_grad(cmd, sim.data), sim.data).block_until_ready()
    print(f"Initial command: {cmd}")
    t0 = time.perf_counter()
    for _ in range(n_iters):
        grad = step_grad(cmd, sim.data)
        cmd = cmd - lr * grad
    t1 = time.perf_counter()
    print(f"Loss: {step(cmd, sim.data)}\nGradient: {grad}")
    print(f"Time taken: {t1 - t0:.2e}s ({(t1 - t0) / n_iters:.2e}s per step)")
    # The cost only depends on the drone's height, so the updates should mainly change the
    # collective thrust (4th element). Assuming the drone starts below the 1 m target, the thrust
    # should increase.
    print(f"Final command: {cmd}")


if __name__ == "__main__":
    main()
Domain randomization
Varying physical parameters per world after reset. Every drone gets a slightly different mass, and the drones in the first world also get a perturbed inertia, so identical commands produce diverging trajectories.
import jax
import numpy as np
from crazyflow.control import Control
from crazyflow.randomize import randomize_inertia, randomize_mass
from crazyflow.sim import Sim
from crazyflow.utils import grid_2d
def main():
    """Randomize drone mass and inertia across worlds, then fly all drones with one command.

    Three worlds with four drones each receive the same state command: xy positions from
    ``grid_2d`` scaled by 0.25 and a height of 0.4. Because mass and inertia are perturbed, the
    trajectories differ slightly between drones.
    """
    # Local import: only needed to symmetrize the inertia noise below.
    import jax.numpy as jnp

    sim = Sim(n_worlds=3, n_drones=4, control=Control.state)
    sim.reset()
    duration = 5.0
    fps = 60
    # Draw mass and inertia noise from independent keys. Reusing one key for both draws would
    # make the two perturbations come from the same random stream.
    key_mass, key_inertia = jax.random.split(jax.random.key(0))
    # Randomize the inertia and mass of the drones
    mask = np.array([True, False, False])  # Only randomize the first world
    mass = sim.data.params.mass
    mass_rng = mass + jax.random.normal(key_mass, (sim.n_worlds, sim.n_drones, 1)) * 5e-3
    J = sim.data.params.J
    J_noise = jax.random.normal(key_inertia, (sim.n_worlds, sim.n_drones, 3, 3)) * 1e-6
    # Inertia tensors are symmetric. Averaging the noise with its transpose keeps J_rng symmetric
    # (assuming the nominal J is).
    J_noise = (J_noise + jnp.swapaxes(J_noise, -1, -2)) / 2
    J_rng = J + J_noise
    randomize_mass(sim, mass_rng, mask)
    # Note: The mask is optional. We can also randomize all worlds at once; this second call
    # applies the mass perturbation to every world, including the masked one.
    randomize_mass(sim, mass_rng)
    randomize_inertia(sim, J_rng, mask)
    cmd = np.zeros((sim.n_worlds, sim.n_drones, 13))
    cmd[..., 2] = 0.4
    cmd[..., :2] = grid_2d(sim.n_drones) * 0.25
    # Simulate for 5 seconds. Each drone should behave slightly differently due to the randomization
    for i in range(int(duration * sim.control_freq)):
        sim.state_control(cmd)
        sim.step(sim.freq // sim.control_freq)
        if ((i * fps) % sim.control_freq) < fps:
            sim.render()
    sim.close()


if __name__ == "__main__":
    main()
Disturbance injection
Inserting a random external force and torque into the step pipeline. The disturbance fires on every physics tick, so the drone fights wind-like perturbations.
import os
import jax
import numpy as np
from numpy.typing import NDArray
from crazyflow.sim import Sim
from crazyflow.sim.data import SimData
os.environ["SCIPY_ARRAY_API"] = "1"
from scipy.spatial.transform import Rotation as R
def disturbance_fn(data: SimData) -> SimData:
    """Overwrite the state's force and torque with fresh Gaussian noise.

    Two subkeys are split off the simulation's RNG key, one for the force and one for the torque.
    The force noise has standard deviation 0.2 (N, world frame) and the torque noise 0.0002
    (Nm, world frame). The advanced key is written back to ``data.core`` so each call draws new
    samples.

    NOTE(review): the samples *replace* ``states.force`` / ``states.torque`` instead of being
    added to them. That is only correct if these fields hold nothing but external forces at the
    point where this stage runs in the pipeline; confirm this against the physics stage.
    """
    key, force_key = jax.random.split(data.core.rng_key)
    key, torque_key = jax.random.split(key)
    states = data.states
    noise_force = jax.random.normal(force_key, states.force.shape) * 0.2  # N, world frame
    noise_torque = jax.random.normal(torque_key, states.torque.shape) * 0.0002  # Nm, world frame
    states = states.replace(force=noise_force, torque=noise_torque)
    return data.replace(states=states, core=data.core.replace(rng_key=key))
def main(plot: bool = False):
    """Fly the same setpoint twice, without and with random disturbances, and optionally plot.

    Args:
        plot: If True, plot both runs with `plot_results` (requires matplotlib).
    """
    sim = Sim(control="state")
    setpoint = np.zeros((sim.n_worlds, sim.n_drones, 13))
    setpoint[..., :3] = 0.2

    def rollout() -> tuple[list, list]:
        """Reset the sim, fly for 3 s and record drone 0's position and quaternion per step."""
        positions, orientations = [], []
        sim.reset()
        for _ in range(3 * sim.control_freq):
            sim.state_control(setpoint)
            sim.step(sim.freq // sim.control_freq)
            positions.append(sim.data.states.pos[0, 0])
            orientations.append(sim.data.states.quat[0, 0])
            sim.render()
        return positions, orientations

    # First run: nominal dynamics
    pos, quat = rollout()

    # Second run
    # We insert the disturbance function into the step pipeline before the integration step. You can
    # inspect the step pipeline with
    # print(sim.step_pipeline)
    sim.step_pipeline = sim.step_pipeline[:2] + (disturbance_fn,) + sim.step_pipeline[2:]
    sim.build_step_fn()
    pos_disturbed, quat_disturbed = rollout()

    sim.close()
    if plot:
        plot_results(pos, pos_disturbed, quat, quat_disturbed)
def plot_results(
    pos: list[NDArray],
    pos_disturbed: list[NDArray],
    quat: list[NDArray],
    quat_disturbed: list[NDArray],
    duration: float = 3.0,
):
    """Plot position and roll/pitch/yaw of the undisturbed vs. the disturbed run.

    Each row of the 3x2 figure shows one axis. The left column holds x/y/z position and the right
    column roll/pitch/yaw (``"xyz"`` Euler angles). Solid lines show the undisturbed run and dashed
    lines the disturbed run.

    Args:
        pos: Per-step positions of the undisturbed run.
        pos_disturbed: Per-step positions of the disturbed run.
        quat: Per-step orientation quaternions of the undisturbed run, in the format accepted by
            scipy's ``Rotation.from_quat``.
        quat_disturbed: Per-step orientation quaternions of the disturbed run.
        duration: Time span in seconds covered by the recordings. The samples are spread evenly
            over ``[0, duration]``. Defaults to 3 s, the rollout length used in ``main``.
    """
    # Only import if plotting is desired to avoid a dependency on matplotlib
    import matplotlib.pyplot as plt

    pos, pos_disturbed = np.array(pos), np.array(pos_disturbed)
    # Stack the quaternions into plain (N, 4) NumPy arrays, the same way as the positions.
    rpy = R.from_quat(np.array(quat)).as_euler("xyz")
    rpy_disturbed = R.from_quat(np.array(quat_disturbed)).as_euler("xyz")

    fig, ax = plt.subplots(3, 2, sharex="all", figsize=(10, 6))
    t = np.linspace(0, duration, len(pos))
    rows = zip(("x", "y", "z"), ("roll", "pitch", "yaw"), ("r", "g", "b"))
    for row, (axis_name, angle_name, color) in enumerate(rows):
        # Left column: XYZ position
        ax[row, 0].plot(t, pos[:, row], label=f"{axis_name} undisturbed", color=color)
        ax[row, 0].plot(
            t, pos_disturbed[:, row], label=f"{axis_name} disturbed", color=color, linestyle="--"
        )
        # Right column: RPY angles
        ax[row, 1].plot(t, rpy[:, row], label=f"{angle_name} undisturbed", color=color)
        ax[row, 1].plot(
            t, rpy_disturbed[:, row], label=f"{angle_name} disturbed", color=color, linestyle="--"
        )
    fig.suptitle("Dynamics with disturbance")
    ax[2, 0].set_xlabel("Time (s)")
    ax[2, 1].set_xlabel("Time (s)")
    for _ax in ax.flatten():
        _ax.legend()
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main(plot=True)  # Default is False to disable plotting during testing
Cameras and RGBD
Offscreen rendering returns RGB and depth images on every frame. The FPV camera (fpv_cam) is attached to the drone and moves with it.
import time
import matplotlib.pyplot as plt
import mujoco
import numpy as np
from matplotlib import animation
from crazyflow.control import Control
from crazyflow.sim import Sim
from crazyflow.sim.integration import Integrator
from crazyflow.sim.physics import Physics
def control(t: float, t_tot: float) -> np.ndarray:
    """Return a ``(1, 1, 13)`` state command on an ascending, yawing unit circle.

    Over ``t_tot`` seconds the setpoint goes once around the unit circle centered at the origin,
    starting at (-1, 0). Its height rises linearly from 0.1 to 0.6 and its yaw from 0 to 1.9*pi.
    All velocity, acceleration and rate entries stay zero.
    """
    angle = 2 * np.pi * t / t_tot + np.pi
    command = np.zeros((1, 1, 13))
    command[..., 0] = np.cos(angle)  # x
    command[..., 1] = np.sin(angle)  # y
    command[..., 2] = 0.1 + 0.5 * t / t_tot  # z
    command[..., 9] = 1.9 * np.pi * t / t_tot  # yaw (index 9 of the 13-element state command)
    return command
def add_smiley(sim: Sim):
    """Attach seven small cubes that form a smiley face to the world, then rebuild the sim.

    Each cube is a copy of the ``cube`` body from a tiny MJCF model, attached to a new frame in
    ``sim.spec``'s world body with the suffix ``:<index>``. Afterwards ``sim.build_mjx()`` and
    ``sim.build_reset_fn()`` are called, presumably so the MJX model and the reset function pick
    up the added bodies.
    """
    # Add 3d object to sim
    # create box spec from an XML string
    cube_xml = """
<mujoco model="box_model">
<worldbody>
<body name="cube" pos="0 0 0">
<geom type="box" size="0.05 0.05 0.05" rgba="0.8 0.4 0.2 1"/>
</body>
</worldbody>
</mujoco>
"""
    cube_spec = mujoco.MjSpec.from_string(cube_xml)
    frame = sim.spec.worldbody.add_frame()
    identity_quat = (1, 0, 0, 0)
    face_positions = [
        (0.0, -0.15, 0.6),  # left eye
        (0.0, 0.15, 0.6),  # right eye
        (0.0, -0.2, 0.4),  # mouth corners
        (0.0, 0.2, 0.4),
        (0.0, -0.1, 0.3),  # mouth bottom
        (0.0, 0.0, 0.3),
        (0.0, 0.1, 0.3),
    ]
    for idx, position in enumerate(face_positions):
        attached = frame.attach_body(cube_spec.body("cube"), "", f":{idx}")
        attached.pos = position
        attached.quat = identity_quat
    sim.build_mjx()
    sim.build_reset_fn()
def main(show_plot: bool = False, save_plot: bool = False):
    """Fly a drone past a smiley and animate its FPV camera's RGB and depth images.

    The drone follows `control` (an ascending, yawing circle). On every animation frame, the
    offscreen renderer captures RGB and depth images from the drone-mounted ``fpv_cam:0``
    camera, and the time spent rendering is recorded. Frames are only produced while the
    animation actually runs, i.e. when `show_plot` and/or `save_plot` is set.

    Args:
        show_plot: Display the animation in a matplotlib window (slow).
        save_plot: Save the animation as ``cameras.gif`` using the pillow writer.
    """
    # Setup sim
    sim = Sim(
        n_drones=1,
        control=Control.state,
        integrator=Integrator.rk4,
        physics=Physics.first_principles,
        drone_model="cf2x_T350",
    )
    add_smiley(sim)
    sim.reset()
    # Start directly below the first setpoint of `control`, which is (-1, 0, 0.1).
    pos = sim.data.states.pos.at[...].set([-1, 0, 0])
    states = sim.data.states.replace(pos=pos)
    sim.data = sim.data.replace(states=states)
    duration = 8
    fps = 25
    timings = []
    # Set up matplotlib rendering
    resolution = (160, 120)  # (width, height) in pixels
    # Placeholder images: imshow needs initial data before the first frame is rendered.
    rgb = np.zeros((resolution[1], resolution[0], 3))
    d = np.zeros((resolution[1], resolution[0]))
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    im1 = ax1.imshow(rgb)
    ax1.set_title("RGB")
    ax1.axis("off")
    im2 = ax2.imshow(d, cmap="viridis")
    ax2.set_title("Depth")
    ax2.axis("off")
    fig.tight_layout()

    # Animation setup
    def update_frame(_):  # noqa: ANN202
        """Advance the sim by one animation frame and refresh both images."""
        t = sim.data.core.steps[0, 0] / sim.freq
        sim.state_control(control(t, duration))
        sim.step(sim.freq // fps)
        t1 = time.perf_counter()
        rgbd = sim.render(
            width=resolution[0], height=resolution[1], mode="rgbd_tuple", camera="fpv_cam:0"
        )
        t2 = time.perf_counter()
        timings.append(t2 - t1)
        if rgbd is None:
            return im1, im2
        rgb, depth = rgbd
        im1.set_data(rgb)
        im2.set_data(depth)
        # Rescale the colormap to the depth range of the current frame.
        im2.set_clim(np.nanmin(depth), np.nanmax(depth))
        return im1, im2

    anim = animation.FuncAnimation(fig, update_frame, frames=int(duration * fps), blit=True)
    if show_plot:
        plt.show()  # this is slow
    if save_plot:
        anim.save("cameras.gif", writer="pillow", fps=fps)
    sim.close()
    plt.close(fig)
    # Without show_plot/save_plot the animation never runs, so there may be no timings.
    if not timings:
        print("No frames were rendered; enable show_plot or save_plot to run the animation.")
        return
    t_mean = np.mean(timings)
    print(f"Average render time {t_mean * 1000:.2f}ms, equivalent to {1 / t_mean:.1f}fps")


if __name__ == "__main__":
    main(show_plot=True, save_plot=False)
LED deck and materials
change_material updates the RGBA colour and emission of any named material on any subset of drones at runtime.
import tempfile
from pathlib import Path
import numpy as np
from crazyflow.control.control import Control
from crazyflow.sim import Sim
from crazyflow.sim.visualize import change_material
# MJCF scene used by `main` in place of the default scene. It defines a gradient skybox that
# fades to black (rgb2="0 0 0"), a checkered ground plane and no other bodies; the drones are
# presumably added on top of it by `Sim` (verify). `main` writes this string to a temporary file
# because `Sim` takes an `xml_path`.
# NOTE(review): MuJoCo resolves `meshdir` relative to the XML file's directory (here a temp dir);
# this looks harmless because the scene itself references no meshes, but confirm.
scene_dark_xml = """
<mujoco model="Drone scene">
<option integrator="RK4" density="1.225" viscosity="1.8e-5" timestep="0.001"/>
<compiler inertiafromgeom="false" meshdir="assets" autolimits="true"/>
<statistic center="0 0 2" extent="2.5"/>
<visual>
<rgba haze="0.15 0.25 0.35 0" fog="1 1 1 0"/>
<map fogstart="0" fogend="0"/>
<global azimuth="-20" elevation="-20" ellipsoidinertia="true"/>
</visual>
<asset>
<texture type="skybox" builtin="gradient" rgb1="0.3 0.5 0.7" rgb2="0 0 0" width="512" height="3072"/>
<texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.2 0.3 0.4" rgb2="0.1 0.2 0.3"
markrgb="0.8 0.8 0.8" width="512" height="512"/>
<material name="groundplane" texture="groundplane" texuniform="true" texrepeat="2 2" reflectance="0.2"/>
</asset>
<worldbody>
<geom name="floor" size="0 0 0.05" type="plane" material="groundplane"/>
</worldbody>
</mujoco>
""" # noqa: E501
def main():
    """Spawn 25 drones in one world, lift them 1.5 above their spawn points and pulse their LEDs.

    The dark scene XML is written to a temporary file because `Sim` takes a path
    (``xml_path``). The file is removed again when the function exits, even on error. While
    flying, even-indexed drones update the ``led_top`` material and odd-indexed drones the
    ``led_bot`` material, each with a fixed random colour and a sinusoidal emission.
    """
    # Defined before `try` so the cleanup in `finally` never hits an unbound name, even if
    # creating the temporary file fails.
    tmp_path = None
    try:
        # Use a named temporary xml file. It is closed when the `with` block exits, before `Sim`
        # opens it; on Windows an open NamedTemporaryFile cannot be opened a second time.
        with tempfile.NamedTemporaryFile(suffix=".xml", delete=False) as tmp:
            tmp_path = Path(tmp.name)
            tmp.write(scene_dark_xml.encode())
        sim = Sim(n_drones=25, drone_model="cf21B_500", control=Control.state, xml_path=tmp_path)
        fps = 60
        rgbas = np.random.default_rng(0).uniform(0, 1, (sim.n_drones, 4))
        rgbas[..., 3] = 1.0  # opaque
        # State command: hold each drone 1.5 above its initial position.
        init_pos = np.array(sim.data.states.pos[0, :, :])
        cmd = np.zeros((sim.n_worlds, sim.n_drones, 13))
        cmd[:, :, :3] = init_pos
        cmd[:, :, 2] += 1.5
        # Loop-invariant drone groups for the two LED materials.
        even_ids = np.arange(0, sim.n_drones, 2)
        odd_ids = np.arange(1, sim.n_drones, 2)
        for i in range(int(10 * sim.control_freq)):
            sim.state_control(cmd)
            sim.step(sim.freq // sim.control_freq)
            if ((i * fps) % sim.control_freq) < fps:
                # Sinusoid with a 2 s period. NOTE(review): it is negative half of the time;
                # confirm that negative emission values are intended.
                emission = np.sin(i / sim.control_freq * np.pi)
                change_material(
                    sim,
                    mat_name="led_top",
                    drone_ids=even_ids,
                    rgba=rgbas[even_ids, :],
                    emission=emission,
                )
                change_material(
                    sim,
                    mat_name="led_bot",
                    drone_ids=odd_ids,
                    rgba=rgbas[odd_ids, :],
                    emission=emission,
                )
                sim.render()
        sim.close()
    finally:
        # Clean up the temporary file. This is best-effort: a failed unlink must not mask an
        # exception raised above.
        if tmp_path is not None:
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass


if __name__ == "__main__":
    main()
Contact queries
The default collision geometry is a sphere around the drone frame. use_box_collision replaces it with a tighter oriented box, useful for narrow-gap flight and accurate contact debugging.
import numpy as np
from crazyflow.sim import Physics, Sim
from crazyflow.sim.sim import use_box_collision
def main():
    """Spawn multiple drones in multiple worlds and check for contacts."""
    n_worlds, n_drones = 2, 3
    sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=Physics.so_rpy, device="cpu")
    use_box_collision(sim, enable=True)  # Enable box collision for all drones
    fps = 60

    # Attitude command [roll, pitch, yaw, thrust]: level, with 4% more thrust than the weight.
    weight = sim.data.params.mass[0, 0, 0] * 9.81
    att_cmd = np.zeros((sim.n_worlds, sim.n_drones, 4))
    att_cmd[..., 3] = weight * 1.04

    steps_per_cmd = sim.freq // sim.control_freq
    for k in range(int(2 * sim.control_freq)):
        sim.attitude_control(att_cmd)
        sim.step(steps_per_cmd)
        if (k * fps) % sim.control_freq < fps:
            sim.render()

    print(f"Contacts: {sim.contacts().any()}")
    sim.close()


if __name__ == "__main__":
    main()
Raycasting and depth sensing
render_depth fires rays from a camera and returns per-pixel distances. This is faster than full RGB rendering and useful for obstacle sensing or depth-based controllers.
import jax.numpy as jnp
import matplotlib.pyplot as plt
from crazyflow.sim import Sim
from crazyflow.sim.sensors import build_render_depth_fn, render_depth
def main(plot: bool = False):
    """Render raycast depth images from camera 0, first via `render_depth`, then a compiled fn.

    Args:
        plot: If True, show each depth image in a matplotlib window.
    """
    max_dist = 1.5  # Cap max distance for better visualization

    def show(image, title):
        """Display the first world's depth image with a distance colorbar."""
        plt.imshow(image[0], cmap="viridis")
        plt.colorbar(label="Distance (m)")
        plt.title(title)
        plt.show()

    sim = Sim()
    lifted_pos = sim.data.states.pos.at[..., 2].set(0.2)
    sim.data = sim.data.replace(states=sim.data.states.replace(pos=lifted_pos))

    # The easiest way to get depth images is to use the render_depth function
    dist = render_depth(sim, camera=0, resolution=(100, 100), include_drone=False)
    dist = jnp.where(dist > max_dist, jnp.nan, dist)
    if plot:
        show(dist, "Raycast Distance from Camera")

    # We can also build a depth renderer function for better performance if we need maximum speed or
    # more fine-grained control. Here we only render the drone collision geometry to avoid expensive
    # raycasting against the high-poly visual mesh of the drone.
    render_depth_fn = build_render_depth_fn(
        sim.mjx_model, camera=0, resolution=(200, 200), geomgroup=(1, 1, 0, 1, 1, 1, 1, 1)
    )
    dist_fn = render_depth_fn(sim)
    dist_fn = jnp.where(dist_fn > max_dist, jnp.nan, dist_fn)
    if plot:
        show(dist_fn, "Raycast Distance from Camera (Compiled)")


if __name__ == "__main__":
    main(plot=True)
Gymnasium environment
Stepping the figure-8 environment with a constant dummy action that makes the drones climb. The env wraps Sim behind the standard Gymnasium VectorEnv interface.
import gymnasium
import jax.numpy as jnp
import numpy as np
from gymnasium.wrappers.vector import JaxToNumpy # , JaxToTorch
from crazyflow.envs import NormalizeActions # noqa: F401
from crazyflow.utils import enable_cache
def main():
    """Step the figure-eight trajectory env with a constant climb action and render it."""
    enable_cache()
    n_envs = 20  # single source of truth for the env count and the action batch size
    # Create environment that contains a figure eight trajectory. You can parametrize the
    # observation space, i.e., which part of the trajectory is contained in the observation. Please
    # refer to the documentation of the environment for more information.
    envs = gymnasium.make_vec(
        "DroneFigureEightTrajectory-v0",
        num_envs=n_envs,
        freq=50,
        n_samples=10,
        samples_dt=0.1,
        trajectory_time=10.0,
    )
    # NormalizeActions wrapper to clip the actions to [-1, 1] and rescale them for use with common
    # DRL libraries.
    envs = NormalizeActions(envs)
    envs = JaxToNumpy(envs)
    # dummy action for going up (in attitude control)
    action = np.zeros((n_envs, 4), dtype=np.float32)
    action[..., 3] = 0.3
    # Prevent alignment warnings. Related issue: https://github.com/jax-ml/jax/issues/29810
    # TODO: Remove once https://github.com/jax-ml/jax/pull/29963 is merged.
    # The action never changes, so the round trip through JAX is only needed once.
    action = np.asarray(jnp.asarray(action))
    obs, info = envs.reset()
    # Step through the environment
    for _ in range(1_000):
        observation, reward, terminated, truncated, info = envs.step(action)
        envs.render()
    envs.close()


if __name__ == "__main__":
    main()