# Source code for tetris_gymnasium.functional.core

"""Functional core for Tetris Gymnasium."""
from typing import NamedTuple, Tuple

import chex
import jax
import jax.numpy as jnp

from tetris_gymnasium.functional.tetrominoes import Tetrominoes, get_tetromino_matrix


class EnvConfig(NamedTuple):
    """Immutable settings describing a Tetris environment.

    Attributes:
        width (int): Number of playable columns on the board.
        height (int): Number of playable rows on the board.
        padding (int): Thickness of the padding placed around the board.
        queue_size (int): Number of entries in the tetromino queue.
        gravity_enabled (bool): Gravity on/off flag; defaults to ``True``.
    """

    # Field order is part of the tuple interface; positional construction
    # depends on it.
    width: int
    height: int
    padding: int
    queue_size: int
    gravity_enabled: bool = True
@chex.dataclass
class State:
    """Complete state of one Tetris game, stored as arrays.

    The bracketed shapes in the field comments are the ones noted by the
    original author.
    """

    # Random number generator key carried along with the game state.
    rng_key: chex.PRNGKey
    # Game board, shape [H, W].
    board: chex.Array
    # Active tetromino, shape [1].
    active_tetromino: chex.Array
    # Rotation of the active tetromino, shape [1].
    rotation: chex.Array
    # Horizontal position of the active tetromino, shape [1].
    x: chex.Array
    # Vertical position of the active tetromino, shape [1].
    y: chex.Array
    # Tetromino queue, shape [L, 1].
    queue: chex.Array
    # Index into the queue, shape [1].
    queue_index: chex.Array
    # holder: Optional[int]  (field left disabled in the original source)
    # Game-over flag, shape [1].
    game_over: chex.Array
    # Score, shape [1].
    score: chex.Array
# Utility functions
def create_board(config: EnvConfig, tetrominoes: Tetrominoes) -> chex.Array:
    """Build an empty, padded Tetris board.

    The playable area is ``config.height`` x ``config.width`` zeros (int8).
    Padding of thickness ``config.padding`` is added below and on both sides,
    but not above, and is filled with ``tetrominoes.base_pixels[1]``.

    Args:
        config: Environment configuration (``height``, ``width``, ``padding``).
        tetrominoes: Tetrominoes object; ``base_pixels[1]`` is the pad value.

    Returns:
        A 2D array of shape
        ``(height + padding, width + 2 * padding)``.
    """
    pad = config.padding
    wall_value = tetrominoes.base_pixels[1]
    playfield = jnp.zeros((config.height, config.width), dtype=jnp.int8)
    # (rows before, rows after), (cols before, cols after): nothing on top.
    pad_widths = ((0, pad), (pad, pad))
    return jnp.pad(
        playfield, pad_widths, mode="constant", constant_values=wall_value
    )
def get_initial_x_y(
    config: EnvConfig, tetrominoes: Tetrominoes, active_tetromino: int
) -> Tuple[int, int]:
    """Compute the spawn position for a new tetromino.

    The piece is centred horizontally on the padded board, using the width of
    its rotation-0 matrix, and placed on the top row.

    Args:
        config: Environment configuration (``width`` and ``padding`` are read).
        tetrominoes: Tetrominoes object containing tetromino configurations.
        active_tetromino: ID of the tetromino being spawned.

    Returns:
        The ``(x, y)`` spawn coordinates; ``y`` is always ``0``.
    """
    spawn_matrix = get_tetromino_matrix(tetrominoes, active_tetromino, 0)
    board_centre = (config.width + config.padding * 2) // 2
    half_piece_width = spawn_matrix.shape[1] // 2
    return board_centre - half_piece_width, 0
def collision(board: chex.Array, tetromino: chex.Array, x: int, y: int) -> chex.Array:
    """Report whether the tetromino overlaps occupied board cells at ``(x, y)``.

    A window the size of the tetromino is cut out of the board with its
    top-left corner at row ``y``, column ``x``. A collision exists when any
    cell is non-empty (``> 0``) in both the window and the tetromino.

    Note: ``jax.lax.dynamic_slice`` clamps out-of-range start indices, so a
    position past the board edge is evaluated against the nearest in-bounds
    window instead of raising.

    Args:
        board: The current state of the Tetris board.
        tetromino: The tetromino matrix to test.
        x: Column of the tetromino's top-left corner.
        y: Row of the tetromino's top-left corner.

    Returns:
        A boolean scalar array, true when there is a collision.
    """
    rows, cols = tetromino.shape
    window = jax.lax.dynamic_slice(board, (y, x), (rows, cols))
    overlap = jnp.logical_and(window > 0, tetromino > 0)
    return overlap.any()
def project_tetromino(
    board: chex.Array, tetromino: chex.Array, x: int, y: int, tetromino_id: int
) -> chex.Array:
    """Draw a tetromino onto the board at ``(x, y)``.

    The tetromino matrix is scaled by ``tetromino_id``, written into an
    otherwise empty layer the size of the board, and that layer is added to
    the board. No overlap check is done here: values are summed cell by cell.

    Args:
        board: The current state of the Tetris board.
        tetromino: The tetromino matrix to draw.
        x: Column of the tetromino's top-left corner.
        y: Row of the tetromino's top-left corner.
        tetromino_id: Value written into the cells the tetromino occupies.

    Returns:
        A new board array with the tetromino added.
    """
    stamped_piece = tetromino * tetromino_id
    blank_layer = jnp.zeros_like(board)
    overlay = jax.lax.dynamic_update_slice(blank_layer, stamped_piece, (y, x))
    return board + overlay
def score(config: EnvConfig, rows_cleared: int) -> chex.Array:
    """Calculate the reward for clearing ``rows_cleared`` rows at once.

    Reward table:
        * 0 rows -> 0
        * 4 rows -> 800
        * any other positive count ``n`` -> ``200 * n - 100``
          (100, 300, 500 for 1, 2, 3 rows)

    Args:
        config: Environment configuration. Not read by this function.
        rows_cleared: The number of rows cleared (Python int or integer array).

    Returns:
        The reward as an ``int32`` scalar array. It is not ``uint8``: the
        maximum reward of 800 would not fit in 8 bits.
    """
    rows = jnp.asarray(rows_cleared)
    standard_clears_reward = jnp.where(rows > 0, rows * 200 - 100, 0)
    # Four rows at once pays a flat 800 rather than the linear 700.
    reward = jnp.where(rows == 4, 800, standard_clears_reward)
    return reward.astype(jnp.int32)
# Core game logic functions


def graviy_step(
    tetrominoes: Tetrominoes,
    board: chex.Array,
    active_tetromino: int,
    rotation: int,
    x: int,
    y: int,
) -> chex.Array:
    """Apply one step of gravity to the active tetromino.

    The piece moves down one row when the cell layout at ``y + 1`` does not
    collide with the board; otherwise it stays where it is.

    NOTE: the function name is misspelled ("graviy") in the original API and
    is kept unchanged so existing callers keep working.

    Args:
        tetrominoes: Tetrominoes object containing tetromino configurations.
        board: The current state of the Tetris board.
        active_tetromino: ID of the active tetromino.
        rotation: Current rotation of the active tetromino.
        x: Current x-coordinate of the active tetromino.
        y: Current y-coordinate of the active tetromino.

    Returns:
        The new y-coordinate after applying gravity (a single value, not a
        tuple): ``y + 1`` if the move is free, else ``y``.
    """
    tetromino_matrix = get_tetromino_matrix(tetrominoes, active_tetromino, rotation)
    blocked_below = collision(board, tetromino_matrix, x, y + 1)
    new_y = jax.lax.cond(
        ~blocked_below,
        lambda: y + 1,
        lambda: y,
    )
    return new_y
def clear_filled_rows(
    config: EnvConfig, tetrominoes: Tetrominoes, board: chex.Array
) -> Tuple[chex.Array, chex.Array]:
    """Remove completely filled rows and let the rows above them fall.

    The playable area is cut out of the padded board. Every row whose cells
    are all ``> 0`` is removed, the remaining rows keep their relative order
    and move to the bottom, and the same number of empty rows is inserted at
    the top. The result is re-padded with ``tetrominoes.base_pixels[1]``.

    Args:
        config: Environment configuration (``height``, ``width``, ``padding``).
        tetrominoes: Tetrominoes object; ``base_pixels[1]`` is the pad value.
        board: The padded board, expected to have shape
            ``(height + padding, width + 2 * padding)``.

    Returns:
        A tuple ``(board, n_filled)``: the updated padded board and the number
        of rows that were cleared. When nothing is filled the input board is
        returned unchanged.
    """
    pad = config.padding
    # Explicit end bounds instead of negative ones: `board[:-0]` would be
    # empty, which silently disabled row clearing for padding == 0.
    sub_board = board[: config.height, pad : pad + config.width]
    filled_rows_mask = jnp.all(sub_board > 0, axis=1)
    n_filled = jnp.sum(filled_rows_mask)

    def clear_rows(playfield):
        row_ids = jnp.arange(config.height)
        # Filled rows become -1; sorting moves those slots to the top while
        # the surviving row indices stay in ascending (original) order.
        source_rows = jnp.sort(jnp.where(filled_rows_mask, -1, row_ids))
        is_blank = source_rows < 0
        # Gather with in-bounds indices only and blank the cleared slots
        # explicitly. Relying on jnp.take's out-of-bounds fill is fragile:
        # negative indices wrap NumPy-style, so a sentinel of -height
        # addresses row 0 and would copy that row instead of inserting zeros.
        gathered = jnp.take(playfield, jnp.maximum(source_rows, 0), axis=0)
        compacted = jnp.where(is_blank[:, None], 0, gathered)
        # Restore the padding around the compacted playfield.
        return jnp.pad(
            compacted,
            ((0, pad), (pad, pad)),
            mode="constant",
            constant_values=tetrominoes.base_pixels[1],
        )

    new_board = jax.lax.cond(n_filled > 0, clear_rows, lambda _: board, sub_board)
    return new_board, n_filled
def hard_drop(board: chex.Array, tetromino: chex.Array, x: int, y: int) -> tuple:
    """Drop the tetromino straight down until the next row would collide.

    Args:
        board: The current state of the Tetris board.
        tetromino: The tetromino matrix to drop.
        x: The x-coordinate of the tetromino's position.
        y: The y-coordinate the drop starts from.

    Returns:
        A tuple ``(new_y, reward)``: the final y-coordinate and a reward of
        2 points per row travelled.
    """

    def can_descend(row):
        # Keep falling while the position one row further down is free.
        return jnp.logical_not(collision(board, tetromino, x, row + 1))

    landing_y = jax.lax.while_loop(can_descend, lambda row: row + 1, y)
    rows_dropped = landing_y - y
    return landing_y, 2 * rows_dropped
def lock_active_tetromino(
    config: EnvConfig, tetrominoes: Tetrominoes, board, active_tetromino, rotation, x, y
) -> Tuple[chex.Array, chex.Array, chex.Array]:
    """Fix the active tetromino onto the board and resolve completed rows.

    Steps: draw the piece at ``(x, y)`` using its ID from ``tetrominoes.ids``,
    clear any rows that became full, then score the number of cleared rows.

    Args:
        config: Environment configuration.
        tetrominoes: Tetrominoes object containing tetromino configurations.
        board: The current state of the Tetris board.
        active_tetromino: ID of the active tetromino.
        rotation: Current rotation of the active tetromino.
        x: Current x-coordinate of the active tetromino.
        y: Current y-coordinate of the active tetromino.

    Returns:
        A tuple ``(board, reward, lines_cleared)``: the updated board, the
        reward for the cleared rows, and the number of rows cleared.
    """
    piece = get_tetromino_matrix(tetrominoes, active_tetromino, rotation)
    piece_id = tetrominoes.ids[active_tetromino]

    # Place the piece, then remove whatever rows it completed.
    board_with_piece = project_tetromino(board, piece, x, y, piece_id)
    settled_board, lines_cleared = clear_filled_rows(
        config, tetrominoes, board_with_piece
    )

    return settled_board, score(config, lines_cleared), lines_cleared
def check_game_over(
    tetrominoes: Tetrominoes, board, active_tetromino, rotation, x, y
) -> bool:
    """Decide whether the game is over for the given tetromino placement.

    The game is over when the tetromino, in the given rotation, already
    collides with the board at ``(x, y)``.

    Args:
        tetrominoes: Tetrominoes object containing tetromino configurations.
        board: The current state of the Tetris board.
        active_tetromino: ID of the active tetromino.
        rotation: Current rotation of the active tetromino.
        x: Current x-coordinate of the active tetromino.
        y: Current y-coordinate of the active tetromino.

    Returns:
        The result of the collision check (a boolean scalar array), true when
        the game is over.
    """
    piece = get_tetromino_matrix(tetrominoes, active_tetromino, rotation)
    return collision(board, piece, x, y)