diff --git a/sequence_layers/mlx/attention.py b/sequence_layers/mlx/attention.py index 20b07be..340d083 100644 --- a/sequence_layers/mlx/attention.py +++ b/sequence_layers/mlx/attention.py @@ -2398,12 +2398,6 @@ def __init__( kernel_init=kernel_init, bias_init=bias_init, ) - self._block_size_config = config.block_size - - @property - @override - def block_size(self): - return self._block_size_config @classmethod @override diff --git a/sequence_layers/mlx/attention_test.py b/sequence_layers/mlx/attention_test.py index cdebb47..18f296e 100644 --- a/sequence_layers/mlx/attention_test.py +++ b/sequence_layers/mlx/attention_test.py @@ -270,7 +270,7 @@ def test_from_config(self): mlx_layer, attention.LocalDotProductSelfAttention, ) - self.assertEqual(mlx_layer.block_size, 2) + self.assertEqual(mlx_layer.block_size, 1) x = test_utils.random_sequence(1, 8, 8) y = mlx_layer.layer(x, training=False) diff --git a/sequence_layers/specs/attention_behaviors.py b/sequence_layers/specs/attention_behaviors.py index 22bda20..0c00087 100644 --- a/sequence_layers/specs/attention_behaviors.py +++ b/sequence_layers/specs/attention_behaviors.py @@ -613,6 +613,7 @@ def test_layer_basic(self): x = self.random_sequence(batch_size, time, channels) layer = self.init_layer(layer, x) + self.assertEqual(layer.block_size, 1) self.assertEqual(layer.output_ratio, 1) self.assertEqual(layer.name, 'local_dot_product_self_attention') self.assertEqual( diff --git a/sequence_layers/specs/types_behaviors.py b/sequence_layers/specs/types_behaviors.py index 40a62bc..accc1ca 100644 --- a/sequence_layers/specs/types_behaviors.py +++ b/sequence_layers/specs/types_behaviors.py @@ -438,6 +438,43 @@ def create_steppable(self) -> types_spec.Steppable: class DefaultSteppable(DefaultTestLayer, backend_sl.types.Steppable): """Mock layer for testing.""" + @property + @override + def block_size(self) -> int: + return backend_sl.types.Steppable.block_size.fget(self) + + @property + @override + def output_ratio(self) -> fractions.Fraction: + return backend_sl.types.Steppable.output_ratio.fget(self) + + @property + @override + def supports_step(self) -> bool: + return backend_sl.types.Steppable.supports_step.fget(self) + + @property + @override + def input_latency(self) -> int: + return backend_sl.types.Steppable.input_latency.fget(self) + + @property + @override + def output_latency(self) -> int: + return backend_sl.types.Steppable.output_latency.fget(self) + + @override + def get_accumulated_input_latency(self, input_latency: int) -> int: + return backend_sl.types.Steppable.get_accumulated_input_latency( + self, input_latency + ) + + @override + def get_accumulated_output_latency(self, output_latency: int) -> int: + return backend_sl.types.Steppable.get_accumulated_output_latency( + self, output_latency + ) + @override def layer_with_emits(self, *args, **kwargs): return backend_sl.types.Steppable.layer_with_emits(