# Source code for tetris_gymnasium.envs.tetris_fn

"""Wrappers for the Tetris environment implemented as pure functions."""
from functools import partial
from typing import Tuple

import chex
import jax
import jax.numpy as jnp
from jax import random, vmap

from tetris_gymnasium.functional.core import (
    EnvConfig,
    State,
    check_game_over,
    collision,
    create_board,
    get_initial_x_y,
    graviy_step,
    hard_drop,
    lock_active_tetromino,
    project_tetromino,
)
from tetris_gymnasium.functional.queue import (
    CreateQueueFunction,
    QueueFunction,
    bag_queue_get_next_element,
    create_bag_queue,
)
from tetris_gymnasium.functional.tetrominoes import Tetrominoes, get_tetromino_matrix


def get_feature_observation(
    board,
    x,
    y,
    config: EnvConfig,
) -> chex.Array:
    """Derive hand-crafted board features and return noise seeded by them.

    The following features are computed on the playable area, i.e. the board
    with ``config.padding`` rows removed at the bottom and ``config.padding``
    columns removed on each side:

    * the height of every column, measured from the bottom,
    * the absolute height difference between neighbouring columns,
    * the total number of holes (empty cells below an occupied cell),
    * the height of the tallest column,
    * the ``x`` and ``y`` position of the active tetromino.

    NOTE(review): the feature values themselves are *not* returned. Their sum
    is only used to seed a PRNG key, and the result is uniform noise with the
    same shape as the feature vector. Confirm whether that is intended.

    Args:
        board: The Tetris board, including padding.
        x: Current x position of the active tetromino.
        y: Current y position of the active tetromino.
        config: Environment configuration; only ``padding`` is read.

    Returns:
        Uniform random values drawn with ``minval=0`` and ``maxval=10``, one
        per computed feature. Identical feature sums yield identical output.
    """
    pad = config.padding
    playable_board = board[0:-pad, pad:-pad]
    n_rows, _n_cols = playable_board.shape

    occupied = playable_board > 0
    vacant = playable_board == 0

    # Row index of the topmost occupied cell in each column. Columns without
    # any occupied cell fall back to ``n_rows`` so that their height is zero.
    row_ids = jnp.arange(n_rows)[:, None]
    topmost_row = jnp.min(jnp.where(occupied, row_ids, n_rows), axis=0)
    heights = n_rows - topmost_row

    maximum_height = jnp.max(heights)

    # Bumpiness: absolute difference between each pair of adjacent columns.
    height_differences = jnp.abs(jnp.diff(heights))

    # A cell is a hole when it is empty and an occupied cell exists somewhere
    # above it in the same column. The running count of occupied cells from
    # the top tells us whether anything has been seen above (or at) a row; for
    # an empty cell "at" contributes nothing, so this equals "strictly above".
    covered = jnp.cumsum(occupied.astype(jnp.int32), axis=0) > 0
    holes = jnp.sum((covered & vacant).astype(jnp.int32))

    base_features = jnp.concatenate(
        (
            heights,
            height_differences,
            jnp.array([holes]),
            jnp.array([maximum_height]),
            jnp.array([x, y]),
        ),
        axis=0,
    )

    # The absolute feature sum, cast to int32, acts as the PRNG seed.
    seed = jnp.abs(jnp.sum(base_features)).astype(jnp.int32)
    noise_key = jax.random.PRNGKey(seed)

    return jax.random.uniform(
        noise_key, shape=base_features.shape, minval=0, maxval=10
    )


def get_observation(
    board,
    x,
    y,
    active_tetromino,
    rotation,
    game_over,
    tetrominoes: Tetrominoes,
    config: EnvConfig,
) -> chex.Array:
    """Build the board observation for the current state.

    The board is reduced to an int8 occupancy grid (1 where the cell value is
    positive, 0 elsewhere). Unless the game is over, the active tetromino is
    passed to ``project_tetromino`` together with the value ``-1``. Finally
    the padding (``config.padding`` rows at the bottom and columns on each
    side) is cropped away.

    Args:
        board: The Tetris board, including padding.
        x: x position of the active tetromino.
        y: y position of the active tetromino.
        active_tetromino: Identifier of the active tetromino.
        rotation: Rotation index of the active tetromino.
        game_over: Whether the game has ended.
        tetrominoes: Tetromino definitions.
        config: Environment configuration; only ``padding`` is read.

    Returns:
        The cropped observation array.
    """
    pad = config.padding
    piece_matrix = get_tetromino_matrix(tetrominoes, active_tetromino, rotation)
    occupancy = jnp.greater(board, 0).astype(jnp.int8)

    def with_active_piece():
        return project_tetromino(occupancy, piece_matrix, x, y, -1)

    def board_only():
        return occupancy

    rendered = jax.lax.cond(
        jnp.logical_not(game_over), with_active_piece, board_only
    )
    return rendered[0:-pad, pad:-pad]
def update_state(action, config, state, queue_fn, tetrominoes):
    """Apply one action, then optional gravity, to a running game.

    ``action`` indexes the branch tables below: 0 move left, 1 move right,
    2 move down, 3 rotate counter-clockwise, 4 rotate clockwise, 5 no-op,
    6 hard drop. Moves and rotations are only applied when ``collision``
    reports no collision at the target pose.

    Args:
        action: Integer action index in ``[0, 6]``.
        config: Environment configuration; ``gravity_enabled`` is read here.
        state: Current game state.
        queue_fn: Function used to draw the next tetromino when a piece locks.
        tetrominoes: Tetromino definitions.

    Returns:
        A ``(new_state, lines_cleared)`` pair. ``lines_cleared`` is 0 when no
        piece was locked during this update.
    """
    board = state.board
    piece_id = state.active_tetromino
    x, y, rotation = state.x, state.y, state.rotation
    active_tetromino_matrix = get_tetromino_matrix(tetrominoes, piece_id, rotation)

    # --- Horizontal movement (actions 0 and 1) -----------------------------
    def move_left():
        blocked = collision(board, active_tetromino_matrix, x - 1, y)
        return jax.lax.cond(~blocked, lambda: x - 1, lambda: x)

    def move_right():
        blocked = collision(board, active_tetromino_matrix, x + 1, y)
        return jax.lax.cond(~blocked, lambda: x + 1, lambda: x)

    def keep_x():
        return x

    x = jax.lax.switch(
        action,
        [move_left, move_right, keep_x, keep_x, keep_x, keep_x, keep_x],
    )

    # --- Vertical movement (actions 2 and 6) -------------------------------
    def move_down():
        blocked = collision(board, active_tetromino_matrix, x, y + 1)
        lowered_y = jax.lax.cond(~blocked, lambda: y + 1, lambda: y)
        # The reward equals the number of rows actually descended (0 or 1).
        return lowered_y, jnp.int32(lowered_y - y)

    def keep_y():
        return y, 0

    def drop_to_bottom():
        return hard_drop(board, active_tetromino_matrix, x, y)

    y, drop_reward = jax.lax.switch(
        action,
        [keep_y, keep_y, move_down, keep_y, keep_y, keep_y, drop_to_bottom],
    )

    # --- Rotation (actions 3 and 4) ----------------------------------------
    def try_rotation(candidate_rotation):
        candidate_matrix = get_tetromino_matrix(
            tetrominoes, piece_id, candidate_rotation
        )
        return jax.lax.cond(
            ~collision(board, candidate_matrix, x, y),
            lambda: (candidate_rotation, candidate_matrix),
            lambda: (rotation, active_tetromino_matrix),
        )

    def rotate_counterclockwise():
        return try_rotation((rotation - 1) % 4)

    def rotate_clockwise():
        return try_rotation((rotation + 1) % 4)

    def keep_rotation():
        return rotation, active_tetromino_matrix

    rotation, active_tetromino_matrix = jax.lax.switch(
        action,
        [
            keep_rotation,
            keep_rotation,
            keep_rotation,
            rotate_counterclockwise,
            rotate_clockwise,
            keep_rotation,
            keep_rotation,
        ],
    )

    # --- Gravity -----------------------------------------------------------
    y_gravity = jax.lax.cond(
        config.gravity_enabled,
        lambda: graviy_step(tetrominoes, board, piece_id, rotation, x, y),
        lambda: y,
    )
    # With gravity on, an unchanged y after the gravity step triggers a lock.
    should_lock = (y_gravity == y) & config.gravity_enabled

    # State carrying the updated pose, before any locking is applied.
    intermediate_state = State(
        rng_key=state.rng_key,
        board=board,
        active_tetromino=piece_id,
        rotation=rotation,
        x=x,
        y=y_gravity,
        queue=state.queue,
        queue_index=state.queue_index,
        game_over=False,
        score=state.score,
    )

    # --- Locking and spawning ----------------------------------------------
    # A hard drop (action 6) always locks the piece.
    def lock_piece():
        return place_active_tetromino(
            config, tetrominoes, intermediate_state, queue_fn
        )

    def keep_playing():
        return intermediate_state, 0, 0

    new_state, lock_reward, lines_cleared = jax.lax.cond(
        (should_lock | (action == 6)), lock_piece, keep_playing
    )

    # --- Score -------------------------------------------------------------
    new_state = new_state.replace(score=new_state.score + drop_reward + lock_reward)
    return new_state, lines_cleared
def step(
    tetrominoes: Tetrominoes,
    state: State,
    action: int,
    config: EnvConfig,
    queue_fn: QueueFunction = bag_queue_get_next_element,
) -> Tuple[State, chex.Array, float, bool, dict]:
    """Advance the environment by one action.

    If ``state.game_over`` is already set, the state is passed through
    unchanged and zero lines are reported; otherwise ``update_state`` is
    applied.

    Args:
        tetrominoes: Tetromino definitions.
        state: Current game state.
        action: Action index (see ``ACTION_ID_TO_NAME``).
        config: Environment configuration.
        queue_fn: Function used to draw the next tetromino.

    Returns:
        A tuple ``(state, observation, reward, terminated, info)`` where
        ``reward`` is the score difference produced by this step and ``info``
        holds ``"lines_cleared"``.
    """

    def frozen():
        return state, 0

    def advance():
        return update_state(action, config, state, queue_fn, tetrominoes)

    next_state, lines_cleared = jax.lax.cond(state.game_over, frozen, advance)

    # A feature-vector observation via get_feature_observation(board, x, y,
    # config) is left disabled here; the board observation is used instead.
    observation = get_observation(
        next_state.board,
        next_state.x,
        next_state.y,
        next_state.active_tetromino,
        next_state.rotation,
        next_state.game_over,
        tetrominoes,
        config,
    )

    reward = next_state.score - state.score
    terminated = next_state.game_over
    info = {"lines_cleared": lines_cleared}
    return next_state, observation, reward, terminated, info
def reset(
    tetrominoes: Tetrominoes,
    key: chex.PRNGKey,
    config: EnvConfig,
    create_queue_fn: CreateQueueFunction = create_bag_queue,
    queue_fn: QueueFunction = bag_queue_get_next_element,
) -> Tuple[chex.PRNGKey, State, chex.Array]:
    """Create a fresh game and return its first observation.

    Args:
        tetrominoes: Tetromino definitions.
        key: PRNG key used to initialise the piece queue.
        config: Environment configuration.
        create_queue_fn: Function that builds the initial piece queue.
        queue_fn: Function that draws the next tetromino from the queue.

    Returns:
        A tuple ``(key, state, observation)``: the key returned by
        ``queue_fn``, the initial state and the initial observation.
    """
    empty_board = create_board(config, tetrominoes)

    # ``state_key`` is stored in the state. ``key`` is passed first to
    # ``create_queue_fn`` and then, unchanged, to ``queue_fn``.
    key, state_key = random.split(key)
    queue, queue_index = create_queue_fn(config, key)
    first_piece, queue, queue_index, key = queue_fn(config, queue, queue_index, key)

    spawn_x, spawn_y = get_initial_x_y(config, tetrominoes, first_piece)

    state = State(
        rng_key=state_key,
        board=empty_board,
        active_tetromino=first_piece,
        rotation=0,
        x=spawn_x,
        y=spawn_y,
        queue=queue,
        queue_index=queue_index,
        game_over=False,
        score=jnp.float32(0),
    )

    # A feature-vector observation via get_feature_observation(board, x, y,
    # config) is left disabled here; the board observation is used instead.
    observation = get_observation(
        state.board,
        state.x,
        state.y,
        state.active_tetromino,
        state.rotation,
        state.game_over,
        tetrominoes,
        config,
    )
    return key, state, observation
def place_active_tetromino(
    config: EnvConfig,
    tetrominoes: Tetrominoes,
    state: State,
    queue_fn: QueueFunction,
) -> Tuple[State, chex.Array, chex.Array]:
    """Lock the active piece into the board and spawn the next one.

    Args:
        config: Environment configuration.
        tetrominoes: Tetromino definitions.
        state: State whose active tetromino should be locked.
        queue_fn: Function that draws the next tetromino from the queue.

    Returns:
        A tuple ``(new_state, reward, lines_cleared)``, where ``reward`` and
        ``lines_cleared`` are the values returned by ``lock_active_tetromino``.
        ``new_state.score`` is copied from ``state`` unchanged.
    """
    # Merge the piece into the board.
    locked_board, reward, lines_cleared = lock_active_tetromino(
        config,
        tetrominoes,
        state.board,
        state.active_tetromino,
        state.rotation,
        state.x,
        state.y,
    )

    # Draw the next piece. The key returned by ``queue_fn`` is discarded; the
    # state's key is advanced below by splitting ``state.rng_key`` instead.
    next_piece, next_queue, next_queue_index, _ = queue_fn(
        config, state.queue, state.queue_index, state.rng_key
    )
    next_rng_key = random.split(state.rng_key)[0]

    spawn_rotation = 0
    spawn_x, spawn_y = get_initial_x_y(config, tetrominoes, next_piece)

    # Game-over status is evaluated for the freshly spawned piece.
    game_over = check_game_over(
        tetrominoes, locked_board, next_piece, spawn_rotation, spawn_x, spawn_y
    )

    spawned_state = State(
        board=locked_board,
        active_tetromino=next_piece,
        rotation=spawn_rotation,
        x=spawn_x,
        y=spawn_y,
        queue=next_queue,
        queue_index=next_queue_index,
        game_over=game_over,
        score=state.score,
        rng_key=next_rng_key,
    )
    return spawned_state, reward, lines_cleared
def batched_step(
    tetrominoes: Tetrominoes,
    states: State,
    actions: chex.Array,  # [B]
    *,  # Force config to be a keyword argument
    config: EnvConfig,
    queue_fn: QueueFunction = bag_queue_get_next_element,
) -> Tuple[State, chex.Array, chex.Array, chex.Array, dict]:
    """Run ``step`` over a batch of states and actions.

    ``tetrominoes``, ``config`` and ``queue_fn`` are shared by every batch
    element; ``states`` and ``actions`` are mapped over their leading axis.

    NOTE(review): the jitted callable is rebuilt on every invocation of this
    function; this may cause repeated tracing -- confirm if performance
    matters.

    Args:
        tetrominoes: Tetromino definitions.
        states: Batched game states.
        actions: Batched action indices.
        config: Environment configuration (keyword-only).
        queue_fn: Function used to draw the next tetromino.

    Returns:
        The outputs of ``step``, each batched along axis 0.
    """
    per_env_step = partial(step, tetrominoes, config=config, queue_fn=queue_fn)

    # Map over (state, action); all five outputs carry the batch on axis 0.
    vectorized_step = vmap(
        per_env_step,
        in_axes=(0, 0),
        out_axes=(0, 0, 0, 0, 0),
    )
    compiled_step = jax.jit(vectorized_step, static_argnames=["config"])
    return compiled_step(states, actions)
def batched_reset(
    tetrominoes: Tetrominoes,
    keys: chex.PRNGKey,  # [B, 2]
    *,  # Force config to be a keyword argument
    config: EnvConfig,
    create_queue_fn: CreateQueueFunction = create_bag_queue,
    queue_fn: QueueFunction = bag_queue_get_next_element,
    batch_size: int = 1,
) -> Tuple[chex.PRNGKey, State, chex.Array]:
    """Run ``reset`` over a batch of PRNG keys.

    ``tetrominoes``, ``config`` and the queue functions are shared by every
    batch element; only ``keys`` is mapped over its leading axis.

    Args:
        tetrominoes: Tetromino definitions.
        keys: One PRNG key per environment.
        config: Environment configuration (keyword-only).
        create_queue_fn: Function that builds the initial piece queue.
        queue_fn: Function that draws the next tetromino from the queue.
        batch_size: Not used by this function; the batch size is taken from
            the leading axis of ``keys``. Kept for interface compatibility.

    Returns:
        A tuple ``(keys, states, observations)`` in the same order as the
        outputs of ``reset``, each batched along axis 0.
    """
    # Bind everything that is shared across the batch.
    reset_partial = partial(
        reset,
        tetrominoes,
        config=config,
        create_queue_fn=create_queue_fn,
        queue_fn=queue_fn,
    )

    batched_reset_fn = jax.jit(
        vmap(
            reset_partial,
            in_axes=(0,),  # Batch only the key
            out_axes=(0, 0, 0),  # Batch all three outputs: key, state, observation
        ),
        static_argnames=["config"],
    )

    return batched_reset_fn(keys)
# Human-readable names for the integer action ids; the position of each name
# is its action id and matches the branch order used in ``update_state``.
ACTION_ID_TO_NAME = dict(
    enumerate(
        (
            "move_left",
            "move_right",
            "move_down",
            "rotate_counterclockwise",
            "rotate_clockwise",
            "do_nothing",
            "hard_drop",
        )
    )
)