Skip to content

Examples

These examples build on each other — each one introduces one new concept on top of the previous. Start from the top if you're new, or jump to whichever section covers what you need.


Hover

A single drone commanded to hold a fixed height using state control. This is the minimal end-to-end loop: create a Sim, reset it, apply a state command, and step forward.

import numpy as np

from crazyflow.control import Control
from crazyflow.sim import Physics, Sim


def main():
    """Hover a single drone at a fixed setpoint using state control.

    Builds a one-world, one-drone simulation, resets it, and then repeatedly applies the same
    state command while stepping the physics. The simulation is rendered at roughly ``fps``
    frames per simulated second.
    """
    sim = Sim(
        n_worlds=1,
        n_drones=1,
        physics=Physics.first_principles,
        control=Control.state,
        freq=500,
        attitude_freq=500,
        state_freq=100,
        device="cpu",
    )
    sim.reset()

    duration = 5.0
    fps = 60
    n_control_steps = int(duration * sim.control_freq)

    # State cmd is [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate]
    # Only the position entries are non-zero: the setpoint is (0.1, 0.1, 0.1).
    cmd = np.zeros((sim.n_worlds, sim.n_drones, 13))
    cmd[..., :3] = 0.1

    for step_idx in range(n_control_steps):
        sim.state_control(cmd)
        # Advance the physics by as many ticks as fit into one control period
        sim.step(sim.freq // sim.control_freq)
        # Subsample the control steps so that about fps of them per simulated second are drawn
        render_due = ((step_idx * fps) % sim.control_freq) < fps
        if render_due:
            sim.render()
    sim.close()


if __name__ == "__main__":
    main()
python examples/hover.py

Attitude control

Commanding roll, pitch, yaw, and collective thrust directly. This level bypasses the Mellinger position loop and is typical for RL agents that output attitude targets.

import os

import numpy as np

os.environ["SCIPY_ARRAY_API"] = "1"

from scipy.spatial.transform import Rotation as R

from crazyflow.control import Control
from crazyflow.sim import Sim

# Per-axis (x, y, z) gains of the position controller in control(): kp scales the position
# error and kd the velocity error. ki is defined but not referenced by control() in this example.
kp = np.array([0.4, 0.4, 1.25])
ki = np.array([0.05, 0.05, 0.05])
kd = np.array([0.2, 0.2, 0.4])
# Multiplied by the drone mass in control() as the gravity-compensation term
g = 9.81


def control(
    t: float, obs: dict[str, np.ndarray], pos_start: np.ndarray, drone_mass: float
) -> np.ndarray:
    """Compute an attitude command that tracks a rising circular reference.

    The reference position is a unit circle in xy anchored at ``pos_start`` with a height of
    ``0.2 * t``; the reference velocity is zero and the reference yaw equals ``t``.

    Args:
        t: Time used to evaluate the reference.
        obs: Observation with the entries ``"pos"``, ``"vel"`` and ``"quat"`` of one drone.
        pos_start: Start position; only its first two entries (xy) are used.
        drone_mass: Mass used for the gravity compensation term.

    Returns:
        A float32 array ``[roll, pitch, yaw, thrust]``.
    """
    # Reference trajectory
    xy_offset = np.array([np.cos(t) - 1, np.sin(t)])
    pos_ref = np.zeros(3)
    pos_ref[..., :2] = pos_start[:2] + xy_offset
    pos_ref[..., 2] = 0.2 * t
    vel_ref = np.zeros_like(pos_ref)
    yaw_ref = t

    # Tracking errors with respect to the reference
    pos_err = pos_ref - np.array(obs["pos"])
    vel_err = vel_ref - np.array(obs["vel"])

    # PD feedback plus gravity compensation along z
    thrust_vec = np.zeros(3)
    thrust_vec += kp * pos_err
    thrust_vec += kd * vel_err
    thrust_vec[2] += drone_mass * g

    # Project the thrust vector onto the drone's current z axis to get the scalar thrust
    body_z = R.from_quat(obs["quat"]).as_matrix()[:, 2]
    thrust_cmd = thrust_vec.dot(body_z)

    # Build the desired rotation: z points along the thrust vector, x follows the reference yaw
    z_des = thrust_vec / np.linalg.norm(thrust_vec)
    heading = np.array([np.cos(yaw_ref), np.sin(yaw_ref), 0.0])
    y_des = np.cross(z_des, heading)
    y_des /= np.linalg.norm(y_des)
    x_des = np.cross(y_des, z_des)
    rot_des = np.column_stack([x_des, y_des, z_des])

    rpy_des = R.from_matrix(rot_des).as_euler("xyz", degrees=False)
    return np.concatenate([rpy_des, [thrust_cmd]], dtype=np.float32)


def main():
    """Fly one drone along the reference from ``control`` using attitude commands.

    Every control step, the state of drone 0 in world 0 is read from the simulation, converted
    into an attitude command by ``control`` and applied before the physics is stepped.
    """
    sim = Sim(control=Control.attitude)
    sim.reset()
    duration = 6.5
    fps = 60

    cmd = np.zeros((sim.n_worlds, sim.n_drones, 4))  # [roll, pitch, yaw, thrust]
    # Position right after the reset; the reference circle is anchored here
    start = sim.data.states.pos[0, 0]
    n_control_steps = int(duration * sim.control_freq)
    for step_idx in range(n_control_steps):
        states = sim.data.states
        obs = {"pos": states.pos[0, 0], "vel": states.vel[0, 0], "quat": states.quat[0, 0]}
        t = step_idx / sim.control_freq
        mass = sim.data.params.mass[0, 0, 0]
        cmd[0, 0, :] = control(t, obs, start, mass)
        sim.attitude_control(cmd)
        sim.step(sim.freq // sim.control_freq)
        # Subsample the control steps so that about fps of them per simulated second are drawn
        if ((step_idx * fps) % sim.control_freq) < fps:
            sim.render()
    sim.close()


if __name__ == "__main__":
    main()

Gradient descent through dynamics

Because the simulator is built entirely from JAX operations, jax.grad can differentiate through it. The example removes the final floor-clipping stage from the step pipeline, because clipping kills gradients; with it gone, gradients flow freely through the entire trajectory.

import time

import jax
import jax.numpy as jnp
from numpy.typing import NDArray

from crazyflow.control import Control
from crazyflow.sim import Physics, Sim
from crazyflow.sim.data import SimData


def main():
    """Run ten gradient-descent steps on an attitude command through the simulator.

    The cost is the squared distance of the drone's z position from 1 m after ten simulation
    steps. Its gradient with respect to the staged attitude command is obtained with
    ``jax.grad`` and used for plain gradient descent. Loss, gradient and timings are printed.
    """
    sim = Sim(control=Control.attitude, physics=Physics.first_principles, attitude_freq=50)
    # Remove clipping floor function which kills gradients
    sim.step_pipeline = sim.step_pipeline[:-1]
    sim_step = sim.build_step_fn()

    def step(cmd: NDArray, data: SimData) -> jax.Array:
        # Stage the command directly in the data so that it is part of the traced computation
        staged = data.controls.attitude.replace(staged_cmd=cmd)
        data = data.replace(controls=data.controls.replace(attitude=staged))
        data = sim_step(data, 10)
        height_error = data.states.pos[0, 0, 2] - 1.0
        return height_error**2  # Quadratic cost to reach 1m height

    # jax.grad differentiates with respect to the first argument, i.e. the command
    step_grad = jax.jit(jax.grad(step))

    n_iters = 10
    lr = 0.1

    cmd = jnp.zeros((1, 1, 4), dtype=jnp.float32)
    cmd = cmd.at[..., 3].set(sim.data.params.mass[0, 0, 0] * 9.81 * 1.05)

    # Trigger jax's jit to compile the gradient function. This is not necessary, but it ensures that
    # the timings are not affected by the compilation time.
    step_grad(cmd, sim.data).block_until_ready()
    # JAX compiles again if static properties change. Not sure why this is happening here, but this
    # is a simple way to enforce all recompilations before measuring performance.
    step_grad(cmd - lr * step_grad(cmd, sim.data), sim.data).block_until_ready()

    print(f"Initial command: {cmd}")
    t0 = time.perf_counter()
    for _ in range(n_iters):
        grad = step_grad(cmd, sim.data)
        cmd = cmd - lr * grad
    t1 = time.perf_counter()
    print(f"Loss: {step(cmd, sim.data)}\nGradient: {grad}")

    elapsed = t1 - t0
    print(f"Time taken: {elapsed:.2e}s ({elapsed / n_iters:.2e}s per step)")
    # NOTE(review): cmd has 4 entries here and only the last one is initialised above, so
    # presumably the descent mainly changes that last (thrust) entry - confirm.
    print(f"Final command: {cmd}")


if __name__ == "__main__":
    main()

Domain randomization

Varying physical parameters per world after a reset. The drones get slightly perturbed masses and inertias, so identical commands produce diverging trajectories.

import jax
import numpy as np

from crazyflow.control import Control
from crazyflow.randomize import randomize_inertia, randomize_mass
from crazyflow.sim import Sim
from crazyflow.utils import grid_2d


def main():
    """Randomize mass and inertia of the drones and fly all of them to a grid of setpoints.

    Gaussian noise is added to the nominal mass and inertia. The randomization functions are
    called both with a world mask and without one to demonstrate the two ways of using them.
    """
    sim = Sim(n_worlds=3, n_drones=4, control=Control.state)
    sim.reset()

    duration = 5.0
    fps = 60

    # Randomize the inertia and mass of the drones. Each draw gets its own subkey: reusing one
    # PRNG key for both would make the mass and inertia noise come from the same random stream.
    key_mass, key_inertia = jax.random.split(jax.random.key(0))
    mask = np.array([True, False, False])  # Only randomize the first world
    mass = sim.data.params.mass
    mass_rng = mass + jax.random.normal(key_mass, (sim.n_worlds, sim.n_drones, 1)) * 5e-3
    J = sim.data.params.J
    J_rng = J + jax.random.normal(key_inertia, (sim.n_worlds, sim.n_drones, 3, 3)) * 1e-6

    randomize_mass(sim, mass_rng, mask)
    # Note: The mask is optional. We can also randomize all worlds at once
    randomize_mass(sim, mass_rng)
    randomize_inertia(sim, J_rng, mask)

    # State command: xy setpoints on a grid with 0.25 spacing factor, z setpoint at 0.4
    cmd = np.zeros((sim.n_worlds, sim.n_drones, 13))
    cmd[..., 2] = 0.4
    cmd[..., :2] = grid_2d(sim.n_drones) * 0.25

    # Simulate for 5 seconds. Each drone should behave slightly differently due to the randomization
    for i in range(int(duration * sim.control_freq)):
        sim.state_control(cmd)
        sim.step(sim.freq // sim.control_freq)
        # Subsample the control steps so that about fps of them per simulated second are drawn
        if ((i * fps) % sim.control_freq) < fps:
            sim.render()
    sim.close()


if __name__ == "__main__":
    main()

Disturbance injection

Inserting a random external force and torque into the step pipeline. The disturbance fires on every physics tick, so the drone fights wind-like perturbations.

import os

import jax
import numpy as np
from numpy.typing import NDArray

from crazyflow.sim import Sim
from crazyflow.sim.data import SimData

os.environ["SCIPY_ARRAY_API"] = "1"

from scipy.spatial.transform import Rotation as R


def disturbance_fn(data: SimData) -> SimData:
    """Overwrite the force and torque states with zero-mean Gaussian noise.

    The random key stored in ``data.core.rng_key`` is split twice: one subkey per draw, and the
    remaining key is written back so that the next call produces fresh noise.

    Args:
        data: Simulation data before the disturbance is applied.

    Returns:
        The simulation data with replaced force, torque and random key.
    """
    key, force_key = jax.random.split(data.core.rng_key)
    key, torque_key = jax.random.split(key)

    force = jax.random.normal(force_key, data.states.force.shape) * 0.2  # N, world frame
    torque = jax.random.normal(torque_key, data.states.torque.shape) * 0.0002  # Nm, world frame

    disturbed_states = data.states.replace(force=force, torque=torque)
    return data.replace(states=disturbed_states, core=data.core.replace(rng_key=key))


def main(plot: bool = False):
    """Compare an undisturbed run with a run that has ``disturbance_fn`` in the step pipeline.

    Args:
        plot: If True, plot position and orientation of both runs after the simulation.
    """
    sim = Sim(control="state")
    control = np.zeros((sim.n_worlds, sim.n_drones, 13))
    control[..., :3] = 0.2

    def rollout() -> tuple[list, list]:
        # Reset, then record position and quaternion of drone 0 in world 0 after every step
        positions, quats = [], []
        sim.reset()
        for _ in range(3 * sim.control_freq):
            sim.state_control(control)
            sim.step(sim.freq // sim.control_freq)
            states = sim.data.states
            positions.append(states.pos[0, 0])
            quats.append(states.quat[0, 0])
            sim.render()
        return positions, quats

    # First run
    pos, quat = rollout()

    # Second run
    # We insert the disturbance function into the step pipeline before the integration step. You can
    # inspect the step pipeline with
    # print(sim.step_pipeline)
    head, tail = sim.step_pipeline[:2], sim.step_pipeline[2:]
    sim.step_pipeline = head + (disturbance_fn,) + tail
    sim.build_step_fn()
    pos_disturbed, quat_disturbed = rollout()

    sim.close()
    if plot:
        plot_results(pos, pos_disturbed, quat, quat_disturbed)


def plot_results(
    pos: list[NDArray],
    pos_disturbed: list[NDArray],
    quat: list[NDArray],
    quat_disturbed: list[NDArray],
):
    """Plot position and roll/pitch/yaw of the undisturbed and the disturbed run.

    The left column shows x, y and z, the right column roll, pitch and yaw. Undisturbed data is
    drawn with solid lines, disturbed data with dashed lines of the same colour.

    Args:
        pos: Positions recorded in the undisturbed run, one entry per step.
        pos_disturbed: Positions recorded in the disturbed run.
        quat: Quaternions recorded in the undisturbed run.
        quat_disturbed: Quaternions recorded in the disturbed run.
    """
    # Only import if plotting is desired to avoid a dependency on matplotlib
    import matplotlib.pyplot as plt

    pos, pos_disturbed = np.array(pos), np.array(pos_disturbed)
    rpy = R.from_quat(quat).as_euler("xyz")
    rpy_disturbed = R.from_quat(quat_disturbed).as_euler("xyz")
    fig, ax = plt.subplots(3, 2, sharex="all", figsize=(10, 6))
    t = np.linspace(0, 3, len(pos))  # main() records 3 * control_freq steps per run

    colors = ("r", "g", "b")
    # One entry per subplot column: row names, undisturbed data, disturbed data
    columns = (
        (("x", "y", "z"), pos, pos_disturbed),
        (("roll", "pitch", "yaw"), rpy, rpy_disturbed),
    )
    for col, (names, nominal, disturbed) in enumerate(columns):
        for row, (name, color) in enumerate(zip(names, colors)):
            panel = ax[row, col]
            panel.plot(t, nominal[:, row], label=f"{name} undisturbed", color=color)
            panel.plot(
                t, disturbed[:, row], label=f"{name} disturbed", color=color, linestyle="--"
            )

    fig.suptitle("Dynamics with disturbance")
    ax[2, 0].set_xlabel("Time (s)")
    ax[2, 1].set_xlabel("Time (s)")
    for _ax in ax.flatten():
        _ax.legend()
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main(plot=True)  # Default is False to disable plotting during testing

Cameras and RGBD

Offscreen rendering returns RGB and depth images on every frame. The FPV camera (fpv_cam) is attached to the drone and moves with it.

RGB and depth camera outputs from a Crazyflow drone simulation
import time

import matplotlib.pyplot as plt
import mujoco
import numpy as np
from matplotlib import animation

from crazyflow.control import Control
from crazyflow.sim import Sim
from crazyflow.sim.integration import Integrator
from crazyflow.sim.physics import Physics


def control(t: float, t_tot: float) -> np.ndarray:
    """Return the state command for time ``t`` of a rising circular flight.

    The xy setpoint moves once around the unit circle over ``t_tot``, starting at (-1, 0). The
    z setpoint rises linearly from 0.1 to 0.6 and the yaw setpoint from 0 to 1.9 * pi.

    Args:
        t: Current time.
        t_tot: Total duration of the flight.

    Returns:
        A state command of shape (1, 1, 13). All entries except x, y, z and yaw are zero.
    """
    angle = 2 * np.pi * t / t_tot + np.pi
    cmd = np.zeros((1, 1, 13))
    cmd[..., 0] = np.cos(angle)  # x
    cmd[..., 1] = np.sin(angle)  # y
    cmd[..., 2] = 0.1 + 0.5 * t / t_tot  # z
    cmd[..., 9] = 1.9 * np.pi * t / t_tot  # yaw (fourth entry from the end)
    return cmd


def add_smiley(sim: Sim):
    """Attach seven small cubes, arranged as a smiley face, to the simulation scene.

    A cube body is defined from an XML string and attached once per position to a new frame in
    the world body of ``sim.spec``. Afterwards the MJX model and the reset function are rebuilt.

    Args:
        sim: Simulation whose scene is extended.
    """
    # Cube template, created from an XML string
    cube_xml = """
    <mujoco model="box_model">
      <worldbody>
        <body name="cube" pos="0 0 0">
          <geom type="box" size="0.05 0.05 0.05" rgba="0.8 0.4 0.2 1"/>
        </body>
      </worldbody>
    </mujoco>
    """
    cube_spec = mujoco.MjSpec.from_string(cube_xml)
    frame = sim.spec.worldbody.add_frame()

    identity_quat = (1, 0, 0, 0)
    eye_positions = [(0.0, -0.15, 0.6), (0.0, 0.15, 0.6)]
    mouth_positions = [
        (0.0, -0.2, 0.4),
        (0.0, 0.2, 0.4),
        (0.0, -0.1, 0.3),
        (0.0, 0.0, 0.3),
        (0.0, 0.1, 0.3),
    ]
    for idx, position in enumerate(eye_positions + mouth_positions):
        template = cube_spec.body("cube")
        # The suffix keeps the names of the attached copies distinct
        cube = frame.attach_body(template, "", f":{idx}")
        cube.pos = position
        cube.quat = identity_quat

    sim.build_mjx()
    sim.build_reset_fn()


def main(show_plot: bool = False, save_plot: bool = False):
    """Example showing the rendering feature and saving a gif via FuncAnimation.

    The simulation is advanced inside the animation callback, so frames are only simulated and
    rendered while the animation is shown or saved.

    Args:
        show_plot: If True, show the animation in a matplotlib window.
        save_plot: If True, save the animation as ``cameras.gif``.
    """
    # Setup sim
    sim = Sim(
        n_drones=1,
        control=Control.state,
        integrator=Integrator.rk4,
        physics=Physics.first_principles,
        drone_model="cf2x_T350",
    )
    add_smiley(sim)
    sim.reset()
    # Move the drone to (-1, 0, 0), which is where the commanded circle starts
    pos = sim.data.states.pos.at[...].set([-1, 0, 0])
    states = sim.data.states.replace(pos=pos)
    sim.data = sim.data.replace(states=states)
    duration = 8
    fps = 25
    timings = []

    # Set up matplotlib rendering
    resolution = (160, 120)
    rgb = np.zeros((resolution[1], resolution[0], 3))
    d = np.zeros((resolution[1], resolution[0]))
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    im1 = ax1.imshow(rgb)
    ax1.set_title("RGB")
    ax1.axis("off")
    im2 = ax2.imshow(d, cmap="viridis")
    ax2.set_title("Depth")
    ax2.axis("off")
    fig.tight_layout()

    # Animation setup
    def update_frame(_):  # noqa: ANN202
        # Simulated time, derived from the step counter of the first world
        t = sim.data.core.steps[0, 0] / sim.freq
        sim.state_control(control(t, duration))
        sim.step(sim.freq // fps)

        # Only the render call itself is timed
        t1 = time.perf_counter()
        rgbd = sim.render(
            width=resolution[0], height=resolution[1], mode="rgbd_tuple", camera="fpv_cam:0"
        )
        t2 = time.perf_counter()
        timings.append(t2 - t1)
        if rgbd is None:
            return im1, im2
        rgb, depth = rgbd
        im1.set_data(rgb)
        im2.set_data(depth)
        # Rescale the colour map to the depth range of the current frame
        im2.set_clim(np.nanmin(depth), np.nanmax(depth))
        return im1, im2

    anim = animation.FuncAnimation(fig, update_frame, frames=int(duration * fps), blit=True)
    if show_plot:
        plt.show()  # this is slow
    if save_plot:
        anim.save("cameras.gif", writer="pillow", fps=fps)

    sim.close()

    # Without show_plot or save_plot the animation never runs and no timings are collected.
    # Averaging an empty list would emit a warning and report nan.
    if not timings:
        print("No frames were rendered, so no render timing is available")
        return
    t_mean = np.mean(timings)
    print(f"Average render time {t_mean * 1000}ms, eqivalent to {1 / t_mean}fps")


if __name__ == "__main__":
    main(show_plot=True, save_plot=False)
python examples/cameras.py

LED deck and materials

change_material updates the RGBA colour and emission of any named material on any subset of drones at runtime.

Crazyflow drones with runtime-controlled LED deck materials
import tempfile
from pathlib import Path

import numpy as np

from crazyflow.control.control import Control
from crazyflow.sim import Sim
from crazyflow.sim.visualize import change_material

# MuJoCo scene description used by main(): it is written to a temporary file whose path is
# passed to Sim via xml_path. It defines a gradient skybox and a checkered ground plane.
scene_dark_xml = """
<mujoco model="Drone scene">
    <option integrator="RK4" density="1.225" viscosity="1.8e-5" timestep="0.001"/>
    <compiler inertiafromgeom="false" meshdir="assets" autolimits="true"/>
    <statistic center="0 0 2" extent="2.5"/>

    <visual>
        <rgba haze="0.15 0.25 0.35 0" fog="1 1 1 0"/>
        <map fogstart="0" fogend="0"/>
        <global azimuth="-20" elevation="-20" ellipsoidinertia="true"/>
    </visual>

    <asset>
        <texture type="skybox" builtin="gradient" rgb1="0.3 0.5 0.7" rgb2="0 0 0" width="512" height="3072"/>
        <texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.2 0.3 0.4" rgb2="0.1 0.2 0.3"
        markrgb="0.8 0.8 0.8" width="512" height="512"/>
        <material name="groundplane" texture="groundplane" texuniform="true" texrepeat="2 2" reflectance="0.2"/>
    </asset>

    <worldbody>
        <geom name="floor" size="0 0 0.05" type="plane" material="groundplane"/>
    </worldbody>
</mujoco>
"""  # noqa: E501


def main():
    """Spawn 25 drones in one world and activate led decks.

    The dark scene XML is written to a temporary file that is passed to ``Sim`` via ``xml_path``.
    All drones are commanded to hover 1.5 above their initial position while the ``led_top``
    material of the even drones and the ``led_bot`` material of the odd drones are updated on
    every rendered frame. The simulation is closed and the temporary file is removed even if an
    error occurs on the way.
    """
    # Bound before the try block so that the cleanup below also works if file creation fails
    tmp_path = None
    try:
        # Use a named temporary xml file. delete=False keeps it on disk after the with block so
        # that it can be opened again by path.
        with tempfile.NamedTemporaryFile(suffix=".xml", delete=False) as tmp:
            tmp.write(scene_dark_xml.encode())
            tmp.flush()
            tmp_path = Path(tmp.name)

        sim = Sim(n_drones=25, drone_model="cf21B_500", control=Control.state, xml_path=tmp_path)
        try:
            fps = 60
            # One random, fully opaque colour per drone
            rgbas = np.random.default_rng(0).uniform(0, 1, (sim.n_drones, 4))
            rgbas[..., 3] = 1.0

            # State command: hold the initial xy position, 1.5 above the initial height
            init_pos = np.array(sim.data.states.pos[0, :, :])
            cmd = np.zeros((sim.n_worlds, sim.n_drones, 13))
            cmd[:, :, :3] = init_pos
            cmd[:, :, 2] += 1.5

            even_ids = np.arange(0, sim.n_drones, 2)
            odd_ids = np.arange(1, sim.n_drones, 2)

            for i in range(int(10 * sim.control_freq)):
                sim.state_control(cmd)
                sim.step(sim.freq // sim.control_freq)
                # Subsample the control steps so that about fps of them per second are drawn
                if ((i * fps) % sim.control_freq) < fps:
                    # NOTE(review): the sine is negative during every second simulated second -
                    # confirm that change_material accepts negative emission values.
                    emission = np.sin(i / sim.control_freq * np.pi)
                    change_material(
                        sim,
                        mat_name="led_top",
                        drone_ids=even_ids,
                        rgba=rgbas[even_ids, :],
                        emission=emission,
                    )
                    change_material(
                        sim,
                        mat_name="led_bot",
                        drone_ids=odd_ids,
                        rgba=rgbas[odd_ids, :],
                        emission=emission,
                    )
                    sim.render()
        finally:
            sim.close()
    finally:
        # clean up the temporary file
        if tmp_path is not None:
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                # Best-effort cleanup: a leftover temporary file must not mask an earlier error
                pass


if __name__ == "__main__":
    main()
python examples/led_deck.py

Contact queries

The default collision geometry is a sphere around the drone frame. use_box_collision replaces it with a tighter oriented box, useful for narrow-gap flight and accurate contact debugging.

Contact query visualization using the default sphere collision geometry
Contact query visualization using the oriented box collision geometry
import numpy as np

from crazyflow.sim import Physics, Sim
from crazyflow.sim.sim import use_box_collision


def main():
    """Spawn multiple drones in multiple worlds and check for contacts.

    Box collision is enabled for all drones. The drones receive a constant attitude command
    with slightly more than hover thrust, and on every rendered frame it is printed whether
    any contact is reported.
    """
    sim = Sim(n_worlds=2, n_drones=3, physics=Physics.so_rpy, device="cpu")
    use_box_collision(sim, enable=True)  # Enable box collision for all drones
    fps = 60

    # Attitude command with zero roll, pitch and yaw and 4 % more thrust than mass * 9.81
    hover_thrust = sim.data.params.mass[0, 0, 0] * 9.81
    cmd = np.zeros((sim.n_worlds, sim.n_drones, 4))
    cmd[..., 3] = hover_thrust * 1.04

    n_control_steps = int(2 * sim.control_freq)
    for step_idx in range(n_control_steps):
        sim.attitude_control(cmd)
        sim.step(sim.freq // sim.control_freq)
        render_due = ((step_idx * fps) % sim.control_freq) < fps
        if render_due:
            sim.render()
            print(f"Contacts: {sim.contacts().any()}")
    sim.close()


if __name__ == "__main__":
    main()

Raycasting and depth sensing

render_depth fires rays from a camera and returns per-pixel distances. This is faster than full RGB rendering and useful for obstacle sensing or depth-based controllers.

import jax.numpy as jnp
import matplotlib.pyplot as plt

from crazyflow.sim import Sim
from crazyflow.sim.sensors import build_render_depth_fn, render_depth


def main(plot: bool = False):
    """Render depth images with ``render_depth`` and with a pre-built depth render function.

    Args:
        plot: If True, show both depth images with matplotlib.
    """
    max_dist = 1.5

    def show(image, title: str) -> None:
        # Display one depth image with a colour bar
        plt.imshow(image, cmap="viridis")
        plt.colorbar(label="Distance (m)")
        plt.title(title)
        plt.show()

    sim = Sim()
    # Set the z position of all drones to 0.2
    raised_pos = sim.data.states.pos.at[..., 2].set(0.2)
    sim.data = sim.data.replace(states=sim.data.states.replace(pos=raised_pos))

    # The easiest way to get depth images is to use the render_depth function
    dist = render_depth(sim, camera=0, resolution=(100, 100), include_drone=False)
    dist = dist.at[dist > max_dist].set(jnp.nan)  # Cap max distance for better visualization
    if plot:
        show(dist[0], "Raycast Distance from Camera")

    # We can also build a depth renderer function for better performance if we need maximum speed or
    # more fine-grained control. Here we only render the drone collision geometry to avoid expensive
    # raycasting against the high-poly visual mesh of the drone.
    render_depth_fn = build_render_depth_fn(
        sim.mjx_model, camera=0, resolution=(200, 200), geomgroup=(1, 1, 0, 1, 1, 1, 1, 1)
    )
    dist_fn = render_depth_fn(sim)
    dist_fn = dist_fn.at[dist_fn > max_dist].set(jnp.nan)  # Cap max distance as above
    if plot:
        show(dist_fn[0], "Raycast Distance from Camera (Compiled)")


if __name__ == "__main__":
    main(plot=True)
python examples/raycasting.py

Gymnasium environment

Stepping the figure-8 environment with a constant dummy action. The env wraps Sim behind the standard Gymnasium VectorEnv interface.

import gymnasium
import jax.numpy as jnp
import numpy as np
from gymnasium.wrappers.vector import JaxToNumpy  # , JaxToTorch

from crazyflow.envs import NormalizeActions  # noqa: F401
from crazyflow.utils import enable_cache


def main():
    """Step the vectorized figure-eight environment with a constant action and render it."""
    enable_cache()
    num_envs = 20

    # Create environment that contains a figure eight trajectory. You can parametrize the
    # observation space, i.e., which part of the trajectory is contained in the observation. Please
    # refer to the documentation of the environment for more information.
    envs = gymnasium.make_vec(
        "DroneFigureEightTrajectory-v0",
        num_envs=num_envs,
        freq=50,
        n_samples=10,
        samples_dt=0.1,
        trajectory_time=10.0,
    )
    # NormalizeActions wrapper to clip the actions to [-1, 1] and rescale them for use with common
    # DRL libraries. JaxToNumpy is applied on top of it.
    envs = JaxToNumpy(NormalizeActions(envs))

    # dummy action for going up (in attitude control): only the last entry is non-zero
    action = np.zeros((num_envs, 4), dtype=np.float32)
    action[..., 3] = 0.3

    _obs, _info = envs.reset()
    # Step through the environment
    for _ in range(1_000):
        # Prevent alignment warnings. Related issue: https://github.com/jax-ml/jax/issues/29810
        # TODO: Remove once https://github.com/jax-ml/jax/pull/29963 is merged.
        action = np.asarray(jnp.asarray(action))
        _obs, _reward, _terminated, _truncated, _info = envs.step(action)
        envs.render()

    envs.close()


if __name__ == "__main__":
    main()