Tetris (JAX)¶
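| Description | Details |
|---|---|
| Action Space | 7 discrete actions (IDs 0–6) |
| Observation Space | 2D integer array of shape (height, width) |
| Import | from tetris_gymnasium.envs.tetris_fn import reset, step |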
Description¶
The functional environment is a pure-function, JAX-native implementation of Tetris.
Unlike the Gymnasium-based environment, it has no class hierarchy, no gym.Env API,
and no internal mutable state.
Every operation is a stateless function that takes the current state as input and
returns a new state — making it fully compatible with jax.jit, jax.vmap, and
jax.grad.
Key benefits over the standard environment:
- JIT compilation: wrap reset or step with jax.jit for fast repeated calls.
- Vectorised environments: use batched_reset / batched_step (or jax.vmap) to run thousands of environments in parallel on a single device.
- Functional purity: no hidden state; reproducibility is guaranteed by the PRNG key.
- No Gymnasium dependency: suitable for JAX-only training pipelines.
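The state-in, state-out pattern behind these benefits can be illustrated without the Tetris code itself. The sketch below uses a hypothetical toy_step function (it is not part of tetris_gymnasium) to show that a jitted pure function returns the same result whenever it receives the same key and state.

```python
import jax
import jax.numpy as jnp


def toy_step(key, position):
    """Hypothetical pure step: returns a new key and a new position."""
    key, subkey = jax.random.split(key)
    # Draw a move in {-1, 0, 1}; all randomness comes from the key argument.
    move = jax.random.randint(subkey, shape=(), minval=-1, maxval=2)
    return key, position + move


jit_toy_step = jax.jit(toy_step)

key = jax.random.PRNGKey(0)
start = jnp.int32(5)

# Identical inputs give identical outputs because there is no hidden state.
_, first = jit_toy_step(key, start)
_, second = jit_toy_step(key, start)
same_result = bool(first == second)
```

Because the function never mutates anything, replaying a run only requires the initial key and the sequence of inputs.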
Configuration¶
The environment is configured through an EnvConfig named tuple:
- class tetris_gymnasium.functional.core.EnvConfig(width: int, height: int, padding: int, queue_size: int, gravity_enabled: bool = True)[source]
Configuration for the Tetris environment. EnvConfig is a named tuple; its fields, in positional order, are:
- width (int): The width of the game board.
- height (int): The height of the game board.
- padding (int): The padding around the game board.
- queue_size (int): The size of the tetromino queue.
- gravity_enabled (bool): Whether gravity is applied. Defaults to True.
Default values used throughout this page:
from tetris_gymnasium.functional.core import EnvConfig
config = EnvConfig(width=10, height=20, padding=4, queue_size=7)
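Since EnvConfig is a named tuple, it behaves like any other Python named tuple. The following sketch uses a stand-in class, ConfigSketch, that only mirrors the documented fields (it is an illustration, not the library class) to show positional access, the gravity_enabled default, and how to derive a modified copy.

```python
from typing import NamedTuple


class ConfigSketch(NamedTuple):
    """Stand-in with the same fields as EnvConfig (illustration only)."""

    width: int
    height: int
    padding: int
    queue_size: int
    gravity_enabled: bool = True


config = ConfigSketch(width=10, height=20, padding=4, queue_size=7)

# Fields are also reachable by position, in declaration order.
width_by_index = config[0]
gravity_default = config.gravity_enabled

# Named tuples are immutable; _replace returns a modified copy.
no_gravity = config._replace(gravity_enabled=False)
```

The same _replace call should work on a real EnvConfig, since it is a standard named-tuple method.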
State¶
All game state is held in a State dataclass:
- class tetris_gymnasium.functional.core.State(rng_key: Array, board: Array | ndarray | bool | number, active_tetromino: Array | ndarray | bool | number, rotation: Array | ndarray | bool | number, x: Array | ndarray | bool | number, y: Array | ndarray | bool | number, queue: Array | ndarray | bool | number, queue_index: Array | ndarray | bool | number, game_over: Array | ndarray | bool | number, score: Array | ndarray | bool | number)[source]¶
State of the Tetris environment.
- items() → a set-like object providing a view on the state's items¶
- keys() → a set-like object providing a view on the state's keys¶
- values() → an object providing a view on the state's values¶
| Field | Shape | Description |
|---|---|---|
| board | (H, W) | Padded board; bedrock padding has value 1 |
| active_tetromino | () | Index into the tetrominoes array (0–6) |
| rotation | () | Current rotation of the active piece (0–3) |
| x | () | Column position of the active piece |
| y | () | Row position of the active piece |
| queue | (L,) | Piece queue (indices into tetrominoes array) |
| queue_index | () | Current position in the queue |
| game_over | () | Boolean flag |
| score | () | Cumulative score |
| rng_key | (2,) | JAX PRNG key for internal randomness |
Tetrominoes¶
The standard set of seven tetrominoes is provided as a pre-built TETROMINOES constant:
from tetris_gymnasium.functional.tetrominoes import TETROMINOES
Pass this object to reset, step, and all core functions.
Basic Usage¶
reset¶
import jax
from tetris_gymnasium.envs.tetris_fn import reset
from tetris_gymnasium.functional.core import EnvConfig
from tetris_gymnasium.functional.tetrominoes import TETROMINOES
config = EnvConfig(width=10, height=20, padding=4, queue_size=7)
key = jax.random.PRNGKey(42)
key, state, observation = reset(TETROMINOES, key, config)
reset returns a 3-tuple:
| Return value | Type | Description |
|---|---|---|
| key | Array | Updated PRNG key (pass to the next call) |
| state | State | Initial game state |
| observation | Array | Initial observation of shape (height, width) |
step¶
from tetris_gymnasium.envs.tetris_fn import step
action = 0 # move left
state, observation, reward, terminated, info = step(
    TETROMINOES, state, action, config
)
step returns a 5-tuple:
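| Return value | Type | Description |
|---|---|---|
| state | State | Updated game state |
| observation | Array | New observation of shape (height, width) |
| reward | float | Reward for this step (the change in state.score) |
| terminated | bool | Whether the game is over |
| info | dict | Auxiliary info dictionary |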
When terminated is True, subsequent calls to step are no-ops — state and
observation are returned unchanged with zero reward.
Random agent example¶
"""Minimal example: random agent using the JAX functional Tetris environment."""
import jax
import jax.numpy as jnp
from tetris_gymnasium.envs.tetris_fn import reset, step
from tetris_gymnasium.functional.core import EnvConfig
from tetris_gymnasium.functional.tetrominoes import TETROMINOES
config = EnvConfig(width=10, height=20, padding=4, queue_size=7)
key = jax.random.PRNGKey(42)
key, state, observation = reset(TETROMINOES, key, config)
terminated = False
total_reward = 0.0
while not terminated:
    # Split the key so each step samples its action from fresh randomness
    key, subkey = jax.random.split(key)
    # Sample one of the 7 actions (IDs 0 to 6; maxval is exclusive)
    action = int(jax.random.randint(subkey, shape=(), minval=0, maxval=7))
    state, observation, reward, terminated, info = step(
        TETROMINOES, state, action, config
    )
    total_reward += float(reward)
print(f"Game over! Score: {total_reward:.0f}")
Actions¶
| ID | Name | Effect |
|---|---|---|
| 0 | | Move active piece one column left |
| 1 | | Move active piece one column right |
| 2 | | Move active piece one row down (soft drop) |
| 3 | | Rotate active piece 90° counter-clockwise |
| 4 | | Rotate active piece 90° clockwise |
| 5 | | No movement; gravity still applies if enabled |
| 6 | | Drop piece instantly to lowest valid position |
Moves that would result in a collision are silently ignored (piece stays in place).
Rewards¶
| Event | Reward formula |
|---|---|
| Lines cleared | |
| Tetris (4 lines) | 800 (flat bonus) |
| Hard drop distance | |
The reward returned by step is the delta of state.score between steps, so it
already combines all sources for that step.
Observation¶
The observation is a 2D integer array of shape (height, width) (padding stripped):
| Value | Meaning |
|---|---|
| | Empty cell |
| | Locked piece |
| | Active (falling) piece |
JIT Compilation¶
Wrap reset and step with jax.jit to compile them once and execute fast:
import jax
from functools import partial
from tetris_gymnasium.envs.tetris_fn import reset, step
from tetris_gymnasium.functional.core import EnvConfig
from tetris_gymnasium.functional.tetrominoes import TETROMINOES
config = EnvConfig(width=10, height=20, padding=4, queue_size=7)
jit_reset = jax.jit(partial(reset, TETROMINOES, config=config))
jit_step = jax.jit(partial(step, TETROMINOES, config=config))
key = jax.random.PRNGKey(0)
key, state, obs = jit_reset(key)
state, obs, reward, terminated, info = jit_step(state, 6)
Batched (Vectorised) Environments¶
Run multiple independent environments in parallel using batched_reset and
batched_step.
These functions use jax.vmap internally and are JIT-compiled automatically.
import jax
import jax.numpy as jnp
from tetris_gymnasium.envs.tetris_fn import batched_reset, batched_step
from tetris_gymnasium.functional.core import EnvConfig
from tetris_gymnasium.functional.tetrominoes import TETROMINOES
config = EnvConfig(width=10, height=20, padding=4, queue_size=7)
BATCH = 64
keys = jax.random.split(jax.random.PRNGKey(0), BATCH)
keys, states, observations = batched_reset(
    TETROMINOES, keys, config=config, batch_size=BATCH
)
# Every environment takes the same action here (0 = move left)
actions = jnp.zeros(BATCH, dtype=jnp.int32)
states, observations, rewards, terminated, info = batched_step(
    TETROMINOES, states, actions, config=config
)
All state fields gain a leading batch dimension of size BATCH.
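To see what a leading batch dimension looks like in practice, the sketch below applies jax.vmap to a hypothetical toy_reset function that returns a small named-tuple state. ToyState and toy_reset are illustrative stand-ins, not the library's State or reset, and the board size is arbitrary.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical state container; not the library's State class."""

    board: jnp.ndarray
    score: jnp.ndarray


def toy_reset(key):
    """Hypothetical single-environment reset returning a small state."""
    board = jax.random.randint(key, shape=(4, 3), minval=0, maxval=2)
    return ToyState(board=board, score=jnp.float32(0.0))


BATCH = 8
keys = jax.random.split(jax.random.PRNGKey(0), BATCH)

# vmap maps toy_reset over the leading axis of keys, one key per environment.
states = jax.vmap(toy_reset)(keys)

board_shape = states.board.shape  # (8, 4, 3)
score_shape = states.score.shape  # (8,)
```

Every field of the returned state is stacked along axis 0, so a scalar field such as score becomes a vector with one entry per environment.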
- tetris_gymnasium.envs.tetris_fn.batched_reset(tetrominoes: Tetrominoes, keys: Array, *, config: EnvConfig, create_queue_fn: Callable[[EnvConfig, Array], Tuple[Array | ndarray | bool | number, int]] = create_bag_queue, queue_fn: Callable[[EnvConfig, Array | ndarray | bool | number, int, Array], Tuple[int, Array | ndarray | bool | number, int, Array]] = bag_queue_get_next_element, batch_size: int = 1) Tuple[Array, Array | ndarray | bool | number, State][source]¶
Vectorized version of reset function that handles batches.
- tetris_gymnasium.envs.tetris_fn.batched_step(tetrominoes: Tetrominoes, states: State, actions: Array | ndarray | bool | number, *, config: EnvConfig, queue_fn: Callable[[EnvConfig, Array | ndarray | bool | number, int, Array], Tuple[int, Array | ndarray | bool | number, int, Array]] = bag_queue_get_next_element) Tuple[State, Array | ndarray | bool | number, Array | ndarray | bool | number, Array | ndarray | bool | number, dict][source]¶
Vectorized version of step function that handles batches of states.
API Reference¶
- tetris_gymnasium.envs.tetris_fn.reset(tetrominoes: Tetrominoes, key: Array, config: EnvConfig, create_queue_fn: Callable[[EnvConfig, Array], Tuple[Array | ndarray | bool | number, int]] = create_bag_queue, queue_fn: Callable[[EnvConfig, Array | ndarray | bool | number, int, Array], Tuple[int, Array | ndarray | bool | number, int, Array]] = bag_queue_get_next_element) Tuple[Array, State, Array | ndarray | bool | number][source]¶
Resets the Tetris environment to its initial state.
- tetris_gymnasium.envs.tetris_fn.step(tetrominoes: Tetrominoes, state: State, action: int, config: EnvConfig, queue_fn: Callable[[EnvConfig, Array | ndarray | bool | number, int, Array], Tuple[int, Array | ndarray | bool | number, int, Array]] = bag_queue_get_next_element) Tuple[State, Array | ndarray | bool | number, float, bool, dict][source]¶
Performs a single step in the Tetris environment.
Core functions¶
- tetris_gymnasium.functional.core.create_board(config: EnvConfig, tetrominoes: Tetrominoes) Array | ndarray | bool | number[source]¶
Creates an empty Tetris board with padding.
- Parameters:
config – Environment configuration.
tetrominoes – Tetrominoes object containing tetromino configurations.
- Returns:
A 2D array representing the empty Tetris board with padding.
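As a rough picture of what a padded board looks like, the plain-Python sketch below builds an empty playfield of zeros surrounded by bedrock cells of value 1. It assumes the bedrock sits on the left, right, and bottom only; that layout is an assumption made for illustration, and make_padded_board is a hypothetical helper, not the library's create_board.

```python
def make_padded_board(width, height, padding):
    """Hypothetical helper: empty cells are 0, bedrock cells are 1."""
    rows = []
    # Playfield rows: bedrock on the left and right, empty cells in between.
    for _ in range(height):
        rows.append([1] * padding + [0] * width + [1] * padding)
    # Bedrock rows below the playfield (assumed: no padding on top).
    for _ in range(padding):
        rows.append([1] * (width + 2 * padding))
    return rows


board = make_padded_board(width=4, height=3, padding=1)
n_rows = len(board)
n_cols = len(board[0])
```

With bedrock in place, a piece that moves past the edge of the playfield overlaps a non-zero cell, so boundary checks reduce to ordinary collision checks.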
- tetris_gymnasium.functional.core.collision(board: Array | ndarray | bool | number, tetromino: Array | ndarray | bool | number, x: int, y: int) Array | ndarray | bool | number[source]¶
Checks if there’s a collision between the tetromino and the board at the given position.
- Parameters:
board – The current state of the Tetris board.
tetromino – The tetromino to check for collision.
x – The x-coordinate of the tetromino’s position.
y – The y-coordinate of the tetromino’s position.
- Returns:
A boolean indicating whether there’s a collision.
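The idea of a collision check can be sketched in plain Python on nested lists. The version below assumes that (x, y) is the column and row of the piece's top-left corner and that any non-zero board cell blocks the piece; both are assumptions for illustration, and the helpers collides and try_move are hypothetical rather than the library's JAX implementation. try_move also shows how a blocked move can leave the piece in place, as described in the Actions section.

```python
def collides(board, piece, x, y):
    """Return True if any filled piece cell overlaps a non-empty board cell."""
    for dy, piece_row in enumerate(piece):
        for dx, filled in enumerate(piece_row):
            if filled and board[y + dy][x + dx] != 0:
                return True
    return False


def try_move(board, piece, x, y, dx):
    """Apply a horizontal move only if the target position is free."""
    new_x = x + dx
    return x if collides(board, piece, new_x, y) else new_x


# 3 rows of playfield (width 4) with one column/row of bedrock (value 1).
board = [
    [1, 0, 0, 0, 0, 1],
    [1, 0, 0, 0, 0, 1],
    [1, 0, 0, 0, 0, 1],
    [1, 1, 1, 1, 1, 1],
]
o_piece = [[1, 1], [1, 1]]

free = collides(board, o_piece, x=1, y=0)        # False
hits_wall = collides(board, o_piece, x=0, y=0)   # True
hits_floor = collides(board, o_piece, x=1, y=2)  # True
moved = try_move(board, o_piece, x=2, y=0, dx=-1)    # 1
blocked = try_move(board, o_piece, x=1, y=0, dx=-1)  # 1 (unchanged)
```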
- tetris_gymnasium.functional.core.hard_drop(board: Array | ndarray | bool | number, tetromino: Array | ndarray | bool | number, x: int, y: int) tuple[source]¶
Performs a hard drop of the tetromino, moving it down as far as possible.
- Parameters:
board – The current state of the Tetris board.
tetromino – The tetromino to drop.
x – The x-coordinate of the tetromino’s position.
y – The y-coordinate of the tetromino’s position.
- Returns:
The final y-coordinate after the hard drop.
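A hard drop can be thought of as repeating a collision check one row further down until the piece would overlap something. The plain-Python sketch below follows that idea with hypothetical helpers collides and hard_drop_y; it returns only the final row, assumes (x, y) is the piece's top-left corner, and relies on a bedrock row at the bottom to stop the loop. It is not the library's JAX implementation.

```python
def collides(board, piece, x, y):
    """Return True if any filled piece cell overlaps a non-empty board cell."""
    for dy, piece_row in enumerate(piece):
        for dx, filled in enumerate(piece_row):
            if filled and board[y + dy][x + dx] != 0:
                return True
    return False


def hard_drop_y(board, piece, x, y):
    """Move the piece down while the next row is still free."""
    while not collides(board, piece, x, y + 1):
        y += 1
    return y


# 5 playfield rows (width 4), bedrock on both sides and below,
# plus one locked cell at row 4, column 1.
board = [
    [1, 0, 0, 0, 0, 1],
    [1, 0, 0, 0, 0, 1],
    [1, 0, 0, 0, 0, 1],
    [1, 0, 0, 0, 0, 1],
    [1, 1, 0, 0, 0, 1],
    [1, 1, 1, 1, 1, 1],
]
o_piece = [[1, 1], [1, 1]]

final_y = hard_drop_y(board, o_piece, x=1, y=0)  # 2
drop_distance = final_y - 0                      # 2
```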
- tetris_gymnasium.functional.core.lock_active_tetromino(config: EnvConfig, tetrominoes: Tetrominoes, board, active_tetromino, rotation, x, y) Tuple[Array | ndarray | bool | number, Array | ndarray | bool | number, Array | ndarray | bool | number][source]¶
Locks the active tetromino in place, clears any filled rows, and calculates the reward.
- Parameters:
config – Environment configuration.
tetrominoes – Tetrominoes object containing tetromino configurations.
board – The current state of the Tetris board.
active_tetromino – ID of the active tetromino.
rotation – Current rotation of the active tetromino.
x – Current x-coordinate of the active tetromino.
y – Current y-coordinate of the active tetromino.
- Returns:
A tuple containing the updated board and the reward.
- tetris_gymnasium.functional.core.clear_filled_rows(config: EnvConfig, tetrominoes: Tetrominoes, board: Array | ndarray | bool | number) Tuple[Array | ndarray | bool | number, Array | ndarray | bool | number][source]¶
Clears filled rows from the board and returns the updated board and number of cleared rows.
- Parameters:
config – Environment configuration.
tetrominoes – Tetrominoes object containing tetromino configurations.
board – The current state of the Tetris board.
- Returns:
A tuple containing the updated board and the number of cleared rows.
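Row clearing can be sketched as filtering out full rows and adding the same number of empty rows at the top. The plain-Python version below works on an unpadded playfield for simplicity, whereas the library function receives the padded board; clear_rows is a hypothetical helper written for illustration only.

```python
def clear_rows(playfield):
    """Remove full rows, refill from the top, and count the cleared rows."""
    width = len(playfield[0])
    # Keep every row that still has at least one empty (0) cell.
    kept = [row for row in playfield if not all(cell != 0 for cell in row)]
    cleared = len(playfield) - len(kept)
    # New empty rows enter at the top so the board keeps its height.
    empty = [[0] * width for _ in range(cleared)]
    return empty + kept, cleared


field = [
    [0, 0, 0, 0],
    [1, 0, 0, 1],
    [1, 1, 1, 1],
    [0, 1, 1, 0],
    [1, 1, 1, 1],
]

new_field, n_cleared = clear_rows(field)
```

Here two rows are full, so n_cleared is 2 and the remaining rows shift down while keeping their relative order.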
- tetris_gymnasium.functional.core.check_game_over(tetrominoes: Tetrominoes, board, active_tetromino, rotation, x, y) bool[source]¶
Checks if the game is over by determining if the new tetromino collides immediately.
- Parameters:
tetrominoes – Tetrominoes object containing tetromino configurations.
board – The current state of the Tetris board.
active_tetromino – ID of the active tetromino.
rotation – Current rotation of the active tetromino.
x – Current x-coordinate of the active tetromino.
y – Current y-coordinate of the active tetromino.
- Returns:
A boolean indicating whether the game is over.