Tetris (JAX)

Action Space        int in [0, 6]
Observation Space   jax.Array of shape (height, width), dtype int8
Import              from tetris_gymnasium.envs.tetris_fn import reset, step

Description

The functional environment is a pure-function, JAX-native implementation of Tetris. Unlike the Gymnasium-based environment, it has no class hierarchy, no gym.Env API, and no internal mutable state. Every operation is a stateless function that takes the current state as input and returns a new state, which makes it compatible with JAX transformations such as jax.jit and jax.vmap.

Key benefits over the standard environment:

  • JIT compilation — wrap reset or step with jax.jit for fast repeated calls.

  • Vectorised environments — use batched_reset / batched_step (or jax.vmap) to run thousands of environments in parallel on a single device.

  • Functional purity — no hidden state; reproducibility is guaranteed by the PRNG key.

  • No Gymnasium dependency — suitable for JAX-only training pipelines.

Configuration

The environment is configured through an EnvConfig named tuple:

class tetris_gymnasium.functional.core.EnvConfig(width: int, height: int, padding: int, queue_size: int, gravity_enabled: bool = True)

Configuration for the Tetris environment. Fields, in positional order:

  • width (int, field 0) – The width of the game board.

  • height (int, field 1) – The height of the game board.

  • padding (int, field 2) – The number of bedrock padding cells around the game board.

  • queue_size (int, field 3) – The size of the tetromino queue.

  • gravity_enabled (bool, field 4, default True) – Whether gravity is enabled.

Default values used throughout this page:

from tetris_gymnasium.functional.core import EnvConfig

config = EnvConfig(width=10, height=20, padding=4, queue_size=7)
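The "Alias for field number" and "Create new instance of" entries in the generated API docs are what Python produces for a named tuple, so EnvConfig should behave like an ordinary immutable, hashable tuple. The sketch below uses a local stand-in with the same fields (it does not import the library) to show two consequences: variants are made with _replace rather than by mutation, and equal configs hash equally, which is what makes a config safe to close over with functools.partial or use as a static argument to jax.jit.

```python
from typing import NamedTuple


# Local stand-in mirroring the documented EnvConfig fields and default;
# the real class lives in tetris_gymnasium.functional.core.
class EnvConfig(NamedTuple):
    width: int
    height: int
    padding: int
    queue_size: int
    gravity_enabled: bool = True


config = EnvConfig(width=10, height=20, padding=4, queue_size=7)

# Named tuples are immutable: derive a variant instead of editing in place.
no_gravity = config._replace(gravity_enabled=False)

# Fields can be read by name or by position ("field number").
print(config.width, config[0])

# Equal configs hash equally, so they work as cache keys / static arguments.
print(hash(config) == hash(EnvConfig(10, 20, 4, 7)))
```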

State

All game state is held in a State dataclass:

class tetris_gymnasium.functional.core.State(rng_key: Array, board: Array | ndarray | bool | number, active_tetromino: Array | ndarray | bool | number, rotation: Array | ndarray | bool | number, x: Array | ndarray | bool | number, y: Array | ndarray | bool | number, queue: Array | ndarray | bool | number, queue_index: Array | ndarray | bool | number, game_over: Array | ndarray | bool | number, score: Array | ndarray | bool | number)

State of the Tetris environment.

State also exposes dict-style items(), keys(), and values() views over its fields.

Field              Shape    Description

board              (H, W)   Padded board; bedrock padding has value 1
active_tetromino   ()       Index into the tetrominoes array (0–6)
rotation           ()       Current rotation of the active piece (0–3)
x                  ()       Column position of the active piece
y                  ()       Row position of the active piece
queue              (L,)     Piece queue (indices into the tetrominoes array)
queue_index        ()       Current position in the queue
game_over          ()       Boolean flag, True once the game has ended
score              ()       Cumulative score
rng_key            (2,)     JAX PRNG key for internal randomness

Here H and W are the dimensions of the padded board, and L is the length of the queue.
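The queue and queue_index fields, together with the create_bag_queue and bag_queue_get_next_element defaults that appear in the reset and step signatures, suggest a "bag" randomiser: each of the seven piece IDs is dealt once per shuffled bag before any piece repeats. The NumPy sketch below illustrates only that general technique. The helper names, the queue length of one bag (7), and the refill-when-exhausted rule are assumptions, not the library's implementation.

```python
import numpy as np

NUM_TETROMINOES = 7  # piece IDs 0 to 6, as in the active_tetromino field


def new_bag(rng: np.random.Generator) -> np.ndarray:
    """Hypothetical helper: one shuffled bag containing each piece ID once."""
    return rng.permutation(NUM_TETROMINOES)


def next_piece(queue: np.ndarray, queue_index: int, rng: np.random.Generator):
    """Hypothetical helper: return (piece, queue, queue_index), refilling the
    queue from a fresh bag once every piece in it has been dealt."""
    if queue_index >= len(queue):
        queue, queue_index = new_bag(rng), 0
    piece = int(queue[queue_index])
    return piece, queue, queue_index + 1


rng = np.random.default_rng(0)
queue, queue_index = new_bag(rng), 0
pieces = []
for _ in range(14):
    piece, queue, queue_index = next_piece(queue, queue_index, rng)
    pieces.append(piece)

# Each consecutive group of seven is a permutation of all piece IDs.
print(sorted(pieces[:7]), sorted(pieces[7:]))
```

A JAX version would replace the Python if with jax.lax.cond and draw bags with jax.random.permutation, threading the PRNG key through the state.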

Tetrominoes

The standard set of seven tetrominoes is provided as a pre-built TETROMINOES constant:

from tetris_gymnasium.functional.tetrominoes import TETROMINOES

Pass this object to reset, step, and every core function that takes a tetrominoes argument.
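The rotation field (0 to 3) selects one of four orientations of the active piece. How TETROMINOES stores those orientations is not described on this page, so the NumPy sketch below only illustrates the geometry behind actions 3 and 4, using a hypothetical 3×3 T-piece matrix and np.rot90 (which turns counter-clockwise for positive k). Counting rotation as quarter turns counter-clockwise is this sketch's own convention; the library's may differ.

```python
import numpy as np

# Hypothetical 3x3 matrix for the T piece (1 = filled cell), pointing up.
T_PIECE = np.array(
    [[0, 1, 0],
     [1, 1, 1],
     [0, 0, 0]],
    dtype=np.int8,
)


def rotate(piece: np.ndarray, rotation: int) -> np.ndarray:
    """Turn the piece counter-clockwise by `rotation` quarter turns."""
    return np.rot90(piece, k=rotation % 4)


ccw = rotate(T_PIECE, 1)   # one counter-clockwise turn: the T points left
cw = rotate(T_PIECE, -1)   # -1 % 4 == 3 counter-clockwise turns, i.e. one clockwise turn
print(ccw)
print(cw)
```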

Basic Usage

reset

import jax
from tetris_gymnasium.envs.tetris_fn import reset
from tetris_gymnasium.functional.core import EnvConfig
from tetris_gymnasium.functional.tetrominoes import TETROMINOES

config = EnvConfig(width=10, height=20, padding=4, queue_size=7)
key = jax.random.PRNGKey(42)

key, state, observation = reset(TETROMINOES, key, config)

reset returns a 3-tuple:

Return value   Type        Description

key            PRNGKey     Updated PRNG key (pass to the next call)
state          State       Initial game state
observation    jax.Array   Initial observation of shape (height, width)

step

from tetris_gymnasium.envs.tetris_fn import step

action = 0  # move left
state, observation, reward, terminated, info = step(
    TETROMINOES, state, action, config
)

step returns a 5-tuple:

Return value   Type        Description

state          State       Updated game state
observation    jax.Array   New observation of shape (height, width)
reward         float       Reward for this step (new_score - old_score)
terminated     bool        True when the game is over
info           dict        {"lines_cleared": int}

When terminated is True, subsequent calls to step are no-ops — state and observation are returned unchanged with zero reward.
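One common way for a pure step function to make the terminal state absorbing, without a Python if on a traced flag, is to compute the candidate next state and then select elementwise between it and the current state. The NumPy sketch below illustrates that selection pattern on a dict of arrays; freeze_if_done and the dict layout are hypothetical, and whether step uses this pattern or jax.lax.cond internally is not documented here. In JAX, jnp.where mapped over the State with jax.tree_util.tree_map plays the same role.

```python
import numpy as np


def freeze_if_done(game_over, old_state: dict, new_state: dict, reward):
    """Hypothetical helper: keep the old state and zero the reward once the
    game is over; otherwise accept the freshly computed step."""
    # np.where selects elementwise, so no Python branch on game_over is needed.
    state = {k: np.where(game_over, old_state[k], new_state[k]) for k in old_state}
    return state, np.where(game_over, 0.0, reward)


old = {"score": np.array(120), "y": np.array(5)}
new = {"score": np.array(420), "y": np.array(0)}

live_state, live_reward = freeze_if_done(np.array(False), old, new, 300.0)
done_state, done_reward = freeze_if_done(np.array(True), old, new, 300.0)
print(live_state["score"], live_reward)  # candidate step accepted
print(done_state["score"], done_reward)  # frozen: old score, zero reward
```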

Random agent example

"""Minimal example: random agent using the JAX functional Tetris environment."""

import jax

from tetris_gymnasium.envs.tetris_fn import reset, step
from tetris_gymnasium.functional.core import EnvConfig
from tetris_gymnasium.functional.tetrominoes import TETROMINOES

config = EnvConfig(width=10, height=20, padding=4, queue_size=7)

key = jax.random.PRNGKey(42)
key, state, observation = reset(TETROMINOES, key, config)

terminated = False
total_reward = 0.0

while not terminated:
    key, subkey = jax.random.split(key)
    action = int(jax.random.randint(subkey, shape=(), minval=0, maxval=7))
    state, observation, reward, terminated, info = step(
        TETROMINOES, state, action, config
    )
    total_reward += float(reward)

print(f"Game over! Score: {total_reward:.0f}")

Actions

ID   Name                      Effect

0    move_left                 Move active piece one column left
1    move_right                Move active piece one column right
2    move_down                 Move active piece one row down (soft drop)
3    rotate_counterclockwise   Rotate active piece 90° counter-clockwise
4    rotate_clockwise          Rotate active piece 90° clockwise
5    do_nothing                No movement; gravity still applies if enabled
6    hard_drop                 Drop piece instantly to lowest valid position

Moves that would result in a collision are silently ignored (piece stays in place).
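The translation actions and the rule that colliding moves are ignored can be sketched as a lookup from action ID to a position offset plus a "try, then keep or discard" check. The NumPy version below assumes (x, y) is the top-left corner of the piece matrix in padded-board coordinates with y growing downward, and that bedrock sits in the side and bottom padding. It covers only actions 0, 1, 2, and 5 (without gravity); MOVES, collides, and try_move are hypothetical names.

```python
import numpy as np

# Hypothetical (dx, dy) offsets for the translation actions (y grows downward).
# Action 5 (do_nothing) is (0, 0); gravity is not modelled in this sketch.
MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, 1), 5: (0, 0)}


def collides(board, piece, x, y) -> bool:
    """True if any filled cell of `piece` overlaps a filled board cell."""
    h, w = piece.shape
    region = board[y:y + h, x:x + w]
    return bool(np.any((region != 0) & (piece != 0)))


def try_move(board, piece, x, y, action):
    """Apply a translation action; keep the old position if it would collide."""
    dx, dy = MOVES[action]
    nx, ny = x + dx, y + dy
    return (x, y) if collides(board, piece, nx, ny) else (nx, ny)


# Tiny padded board: 4x4 playfield, 2 cells of bedrock (value 1) on the sides and bottom.
board = np.ones((4 + 2, 4 + 2 * 2), dtype=np.int8)
board[:4, 2:6] = 0
O_PIECE = np.ones((2, 2), dtype=np.int8)

print(try_move(board, O_PIECE, 2, 0, 0))  # left into the wall: ignored
print(try_move(board, O_PIECE, 2, 0, 1))  # right: accepted
```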

Rewards

Event                Reward formula

Lines cleared        max(lines * 200 - 100, 0)
Tetris (4 lines)     800 (flat bonus)
Hard drop distance   2 × cells dropped

The reward returned by step is the delta of state.score between steps, so it already combines all sources for that step.
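Written out as plain Python, the formulas give 100, 300, and 500 points for one to three lines and 2 points per cell of hard drop. The sketch reads the Tetris row as replacing the formula for four lines (800 rather than 700 plus a bonus); that reading is an assumption. line_clear_reward and hard_drop_reward are hypothetical names, and the last lines show how a per-step reward is simply the difference of two cumulative scores.

```python
def line_clear_reward(lines: int) -> int:
    """Line-clear points from the table; a Tetris is read as a flat 800 (assumption)."""
    if lines == 4:
        return 800
    return max(lines * 200 - 100, 0)


def hard_drop_reward(cells_dropped: int) -> int:
    """Hard-drop points: 2 per cell the piece falls."""
    return 2 * cells_dropped


old_score = 1_000
# Hypothetical step: a hard drop falls 12 cells and clears 2 lines.
new_score = old_score + hard_drop_reward(12) + line_clear_reward(2)
reward = new_score - old_score  # step reports this difference as the reward

print([line_clear_reward(n) for n in range(5)])
print(reward)
```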

Observation

The observation is a 2D integer array of shape (height, width) (padding stripped):

Value   Meaning

0       Empty cell
1       Locked piece
-1      Active (falling) piece

tetris_gymnasium.envs.tetris_fn.get_observation(board, x, y, active_tetromino, rotation, game_over, tetrominoes: Tetrominoes, config: EnvConfig) -> Array | ndarray | bool | number

Returns the observation of the environment.
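One way to build this view from the padded board is to paint the active piece's filled cells as -1 and then crop the padding away. The NumPy sketch below assumes padding bedrock columns on both sides, padding bedrock rows only at the bottom, and (x, y) as the top-left corner of the piece matrix in padded coordinates; all three are assumptions, and sketch_observation is a hypothetical stand-in for get_observation.

```python
import numpy as np


def sketch_observation(board, piece, x, y, height, width, padding):
    """Overlay the active piece as -1, then strip the bedrock padding."""
    view = board.astype(np.int8)       # copy with a signed dtype so -1 fits
    h, w = piece.shape
    region = view[y:y + h, x:x + w]    # basic slice: a view into `view`
    region[piece != 0] = -1            # writes through to `view`
    return view[:height, padding:padding + width]


height, width, padding = 4, 4, 2
board = np.ones((height + padding, width + 2 * padding), dtype=np.int8)
board[:height, padding:padding + width] = 0
board[3, 2:5] = 1  # locked cells on the bottom playfield row

obs = sketch_observation(
    board, np.ones((2, 2)), x=3, y=0, height=height, width=width, padding=padding
)
print(obs)
```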

JIT Compilation

Wrap reset and step with jax.jit so they are compiled on the first call and the compiled code is reused on later calls:

import jax
from functools import partial
from tetris_gymnasium.envs.tetris_fn import reset, step
from tetris_gymnasium.functional.core import EnvConfig
from tetris_gymnasium.functional.tetrominoes import TETROMINOES

config = EnvConfig(width=10, height=20, padding=4, queue_size=7)

jit_reset = jax.jit(partial(reset, TETROMINOES, config=config))
jit_step  = jax.jit(partial(step,  TETROMINOES, config=config))

key = jax.random.PRNGKey(0)
key, state, obs = jit_reset(key)
state, obs, reward, terminated, info = jit_step(state, 6)

Batched (Vectorised) Environments

Run multiple independent environments in parallel using batched_reset and batched_step. These functions use jax.vmap internally and are JIT-compiled automatically.

import jax
import jax.numpy as jnp
from tetris_gymnasium.envs.tetris_fn import batched_reset, batched_step
from tetris_gymnasium.functional.core import EnvConfig
from tetris_gymnasium.functional.tetrominoes import TETROMINOES

config = EnvConfig(width=10, height=20, padding=4, queue_size=7)
BATCH = 64

keys = jax.random.split(jax.random.PRNGKey(0), BATCH)
keys, states, observations = batched_reset(
    TETROMINOES, keys, config=config, batch_size=BATCH
)

# Every environment takes action 0 (move_left) here; set per-index values for different actions
actions = jnp.zeros(BATCH, dtype=jnp.int32)
states, observations, rewards, terminated, info = batched_step(
    TETROMINOES, states, actions, config=config
)

All state fields gain a leading batch dimension of size BATCH.

tetris_gymnasium.envs.tetris_fn.batched_reset(tetrominoes: Tetrominoes, keys: Array, *, config: EnvConfig, create_queue_fn: Callable[[EnvConfig, Array], Tuple[Array | ndarray | bool | number, int]] = create_bag_queue, queue_fn: Callable[[EnvConfig, Array | ndarray | bool | number, int, Array], Tuple[int, Array | ndarray | bool | number, int, Array]] = bag_queue_get_next_element, batch_size: int = 1) -> Tuple[Array, Array | ndarray | bool | number, State]

Vectorized version of reset function that handles batches.

tetris_gymnasium.envs.tetris_fn.batched_step(tetrominoes: Tetrominoes, states: State, actions: Array | ndarray | bool | number, *, config: EnvConfig, queue_fn: Callable[[EnvConfig, Array | ndarray | bool | number, int, Array], Tuple[int, Array | ndarray | bool | number, int, Array]] = bag_queue_get_next_element) -> Tuple[State, Array | ndarray | bool | number, Array | ndarray | bool | number, Array | ndarray | bool | number, dict]

Vectorized version of step function that handles batches of states.

API Reference

tetris_gymnasium.envs.tetris_fn.reset(tetrominoes: Tetrominoes, key: Array, config: EnvConfig, create_queue_fn: Callable[[EnvConfig, Array], Tuple[Array | ndarray | bool | number, int]] = create_bag_queue, queue_fn: Callable[[EnvConfig, Array | ndarray | bool | number, int, Array], Tuple[int, Array | ndarray | bool | number, int, Array]] = bag_queue_get_next_element) -> Tuple[Array, State, Array | ndarray | bool | number]

Resets the Tetris environment to its initial state.

tetris_gymnasium.envs.tetris_fn.step(tetrominoes: Tetrominoes, state: State, action: int, config: EnvConfig, queue_fn: Callable[[EnvConfig, Array | ndarray | bool | number, int, Array], Tuple[int, Array | ndarray | bool | number, int, Array]] = bag_queue_get_next_element) -> Tuple[State, Array | ndarray | bool | number, float, bool, dict]

Performs a single step in the Tetris environment.

Core functions

tetris_gymnasium.functional.core.create_board(config: EnvConfig, tetrominoes: Tetrominoes) -> Array | ndarray | bool | number

Creates an empty Tetris board with padding.

Parameters:
  • config – Environment configuration.

  • tetrominoes – Tetrominoes object containing tetromino configurations.

Returns:

A 2D array representing the empty Tetris board with padding.
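A NumPy sketch of an empty padded board. The layout used here, with padding bedrock columns on both sides and padding bedrock rows along the bottom so that the board is (height + padding) × (width + 2 · padding), is an assumption made for illustration; this page does not spell out create_board's layout, and the sketch ignores the tetrominoes argument the real function takes. sketch_create_board is a hypothetical name.

```python
import numpy as np


def sketch_create_board(height: int, width: int, padding: int) -> np.ndarray:
    """Empty playfield surrounded by bedrock (value 1) on the sides and bottom."""
    board = np.ones((height + padding, width + 2 * padding), dtype=np.int8)
    board[:height, padding:padding + width] = 0  # carve out the playable area
    return board


board = sketch_create_board(height=20, width=10, padding=4)
print(board.shape)               # padded (H, W) as in the State table
print(int((board == 0).sum()))   # number of playable cells
```

Treating walls and floor as ordinary filled cells lets a single overlap test handle both the board edges and locked pieces.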

tetris_gymnasium.functional.core.collision(board: Array | ndarray | bool | number, tetromino: Array | ndarray | bool | number, x: int, y: int) -> Array | ndarray | bool | number

Checks if there’s a collision between the tetromino and the board at the given position.

Parameters:
  • board – The current state of the Tetris board.

  • tetromino – The tetromino to check for collision.

  • x – The x-coordinate of the tetromino’s position.

  • y – The y-coordinate of the tetromino’s position.

Returns:

A boolean indicating whether there’s a collision.

tetris_gymnasium.functional.core.hard_drop(board: Array | ndarray | bool | number, tetromino: Array | ndarray | bool | number, x: int, y: int) -> tuple

Performs a hard drop of the tetromino, moving it down as far as possible.

Parameters:
  • board – The current state of the Tetris board.

  • tetromino – The tetromino to drop.

  • x – The x-coordinate of the tetromino’s position.

  • y – The y-coordinate of the tetromino’s position.

Returns:

The final y-coordinate after the hard drop.
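collision and hard_drop fit together: a hard drop repeats the collision test one row lower until the next row would collide. The NumPy sketch below assumes (x, y) is the top-left corner of the piece matrix in padded-board coordinates and that bedrock rows sit at the bottom, so the loop always stops. sketch_collision and sketch_hard_drop are hypothetical names; the real hard_drop is annotated as returning a tuple, so its exact outputs may differ from the (final_y, cells_dropped) pair returned here. The cells-dropped count is what the 2-points-per-cell hard-drop reward depends on.

```python
import numpy as np


def sketch_collision(board, piece, x, y) -> bool:
    """True if any filled piece cell lands on a filled (or bedrock) board cell."""
    h, w = piece.shape
    return bool(np.any((board[y:y + h, x:x + w] != 0) & (piece != 0)))


def sketch_hard_drop(board, piece, x, y):
    """Move the piece down until the next row would collide.

    Returns (final_y, cells_dropped). A JAX version would use
    jax.lax.while_loop instead of a Python while loop.
    """
    start_y = y
    while not sketch_collision(board, piece, x, y + 1):
        y += 1
    return y, y - start_y


# 4x4 playfield with 2 cells of bedrock padding on the sides and bottom.
board = np.ones((6, 8), dtype=np.int8)
board[:4, 2:6] = 0
board[3, 2] = 1  # a locked cell in the bottom-left playfield corner
O_PIECE = np.ones((2, 2), dtype=np.int8)

print(sketch_hard_drop(board, O_PIECE, x=2, y=0))  # stops on the locked cell
print(sketch_hard_drop(board, O_PIECE, x=4, y=0))  # stops on the floor
```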

tetris_gymnasium.functional.core.lock_active_tetromino(config: EnvConfig, tetrominoes: Tetrominoes, board, active_tetromino, rotation, x, y) -> Tuple[Array | ndarray | bool | number, Array | ndarray | bool | number, Array | ndarray | bool | number]

Locks the active tetromino in place, clears any filled rows, and calculates the reward.

Parameters:
  • config – Environment configuration.

  • tetrominoes – Tetrominoes object containing tetromino configurations.

  • board – The current state of the Tetris board.

  • active_tetromino – ID of the active tetromino.

  • rotation – Current rotation of the active tetromino.

  • x – Current x-coordinate of the active tetromino.

  • y – Current y-coordinate of the active tetromino.

Returns:

A tuple containing the updated board and the reward.
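The first part of locking, writing the piece's filled cells into the board permanently, can be sketched in a few lines of NumPy. The sketch assumes (x, y) is the top-left corner of the piece matrix in padded-board coordinates and writes the value 1, matching the observation's "locked piece" value; sketch_lock is a hypothetical name, and row clearing and reward calculation are left to separate steps.

```python
import numpy as np


def sketch_lock(board, piece, x, y):
    """Return a copy of `board` with the piece's filled cells set to 1."""
    locked = board.copy()
    h, w = piece.shape
    region = locked[y:y + h, x:x + w]  # view into the copy
    region[piece != 0] = 1             # empty matrix cells leave the board as is
    return locked


board = np.ones((6, 8), dtype=np.int8)  # 4x4 playfield, padding 2
board[:4, 2:6] = 0
T_PIECE = np.array([[1, 1, 1], [0, 1, 0]], dtype=np.int8)

after = sketch_lock(board, T_PIECE, x=2, y=2)
print(after[:4, 2:6])
```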

tetris_gymnasium.functional.core.clear_filled_rows(config: EnvConfig, tetrominoes: Tetrominoes, board: Array | ndarray | bool | number) -> Tuple[Array | ndarray | bool | number, Array | ndarray | bool | number]

Clears filled rows from the board and returns the updated board and number of cleared rows.

Parameters:
  • config – Environment configuration.

  • tetrominoes – Tetrominoes object containing tetromino configurations.

  • board – The current state of the Tetris board.

Returns:

A tuple containing the updated board and the number of cleared rows.
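Row clearing can be written without an explicit loop: find the playfield rows whose in-bounds cells are all filled, drop them, and stack fresh empty rows (with their bedrock walls) on top. The NumPy sketch below assumes side and bottom bedrock as in the other sketches; sketch_clear_filled_rows is a hypothetical name. Because the number of cleared rows is data-dependent, the boolean indexing used here is not allowed under jax.jit, so a JAX version needs a fixed-shape formulation instead.

```python
import numpy as np


def sketch_clear_filled_rows(board, height, width, padding):
    """Return (new_board, rows_cleared) for a board with side/bottom bedrock."""
    playfield = board[:height]                     # bottom bedrock rows excluded
    inner = playfield[:, padding:padding + width]
    full = np.all(inner != 0, axis=1)              # one flag per playfield row
    kept = playfield[~full]                        # surviving rows, order preserved
    rows_cleared = int(full.sum())

    # Fresh rows: empty inside, bedrock in the side padding columns.
    empty_row = np.ones(board.shape[1], dtype=board.dtype)
    empty_row[padding:padding + width] = 0
    fresh = np.tile(empty_row, (rows_cleared, 1))

    new_board = np.concatenate([fresh, kept, board[height:]], axis=0)
    return new_board, rows_cleared


height, width, padding = 4, 4, 2
board = np.ones((height + padding, width + 2 * padding), dtype=np.int8)
board[:height, padding:padding + width] = 0
board[3, 2:6] = 1  # bottom playfield row completely filled
board[2, 2:4] = 1  # partial row above it

cleared, n = sketch_clear_filled_rows(board, height, width, padding)
print(n)
print(cleared[:height, padding:padding + width])
```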

tetris_gymnasium.functional.core.check_game_over(tetrominoes: Tetrominoes, board, active_tetromino, rotation, x, y) -> bool

Checks if the game is over by determining if the new tetromino collides immediately.

Parameters:
  • tetrominoes – Tetrominoes object containing tetromino configurations.

  • board – The current state of the Tetris board.

  • active_tetromino – ID of the active tetromino.

  • rotation – Current rotation of the active tetromino.

  • x – Current x-coordinate of the active tetromino.

  • y – Current y-coordinate of the active tetromino.

Returns:

A boolean indicating whether the game is over.

tetris_gymnasium.functional.core.score(config: EnvConfig, rows_cleared: int) -> uint8

Calculates the score based on the number of rows cleared.

Parameters:
  • config – Environment configuration.

  • rows_cleared – The number of rows cleared.

Returns:

The calculated score as a uint8.