From 19c6c8aae71115b632c94c322099be16345a19e5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 20 Mar 2026 17:33:55 -0400 Subject: [PATCH 1/5] fixes --- fast_llm/layers/ssm/gdn.py | 2 +- fast_llm_external_models/apriel2/modeling_apriel2.py | 2 +- tests/utils/model_configs.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index f694d80a6..cf5bc0bc4 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -227,7 +227,7 @@ def __init__( self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft ) - if _fast_gdn_available: + if _fast_gdn_available and distributed_config.use_cuda: self.chunk_gated_delta_rule = chunk_gated_delta_rule else: logger.warning( diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index ea0611953..9e82dfc4f 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -2839,7 +2839,7 @@ def forward( # Reshape back to [batch, num_patches, text_hidden] image_features = image_features.squeeze(0).view(batch_size, num_patches_per_image, -1) - return image_features, (*all_hidden_states, hidden_states, image_features) + return image_features, (*all_hidden_states, hidden_states, image_features) if output_hidden_states else None class SimpleMLP(nn.Module): diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 42802f1c7..3e6910b6f 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -802,7 +802,7 @@ def update_and_add_testing_config( # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! skip_tests=("sdp", "ms", TP_NO_STP), - requires_cuda=False, + requires_cuda=True, # GDN available on CPU, but not in the converted model (also runs very slow). ) _gdn_block = MODEL_CONFIGS["apriel2_gdn"].config_dict["model"]["base_model"]["decoder"]["block"] From 1b6fcd01864263bb408d9f467c1529ac6fb4d2ad Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 20 Mar 2026 17:39:32 -0400 Subject: [PATCH 2/5] fix --- tests/utils/distributed_configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index b085f0994..933ea8f8e 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -70,7 +70,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon if torch.cuda.is_available() else { (None, "norm"): get_config(ignore_tensors=True), - (None, "word_embeddings_weight"): get_config(8e-2, 1e-4), + (None, "embeddings_weight"): get_config(8e-2, 1e-4), } ), (None, "bias"): get_config(2e-2, 1e-3) if torch.cuda.is_available() else get_config(2e-2, 2e-3), From 573c6d84e7be8bdcfa39d6c4aa2a4bccde807f83 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 20 Mar 2026 19:59:21 -0400 Subject: [PATCH 3/5] fix --- fast_llm/data/dataset/streaming.py | 3 +++ tests/models/test_streaming.py | 3 ++- tests/utils/redis.py | 2 -- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index 8835612ec..e3fce4eb3 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -1,5 +1,6 @@ import functools import json +import logging import time import typing @@ -14,6 +15,8 @@ from fast_llm.data.document.token_data import TokenDataDocument from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + @config_class() class RedisStreamingDocumentData(Config): diff --git a/tests/models/test_streaming.py b/tests/models/test_streaming.py index 7b39a62f2..0c40f0a48 100644 --- a/tests/models/test_streaming.py +++ b/tests/models/test_streaming.py @@ -132,7 +132,7 @@ def _run_model_streaming_configs( model_testing_config, None, updates={ - ("data", "datasets"): {"training": {"port": port}}, + ("data", "datasets"): {"training": {"port": port, "timeout": 1.0}}, ("training", "export"): {"format": model_testing_config.checkpoint_format.name, "interval": 1}, "callbacks": { "streaming": { @@ -143,6 +143,7 @@ def _run_model_streaming_configs( "external_world_size": config.consumer_count, }, "export": {"format": model_testing_config.checkpoint_format.name}, + "timeout": 1.0, } }, # Disable tensor logging. diff --git a/tests/utils/redis.py b/tests/utils/redis.py index 8160ef8c0..2dc09bee2 100644 --- a/tests/utils/redis.py +++ b/tests/utils/redis.py @@ -66,8 +66,6 @@ def producer_loop(): @contextlib.contextmanager def fake_redis_server(config: RedisConfig): - # We search for free port as port from previous test can still be not free even after server shutdown - # ----- Monkey-patch handler to suppress broken pipes ----- orig_handle = fakeredis._tcp_server.TCPFakeRequestHandler.handle From 3658c028017ad4791209e33fdd82c4ef42347bbe Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 20 Mar 2026 20:29:56 -0400 Subject: [PATCH 4/5] fix --- fast_llm_external_models/tests/test_apriel2/test_equivalence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py index c5268f23c..8734aa02c 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py +++ b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py @@ -481,7 +481,7 @@ def test_batch_processing_behavior(self, model_pair): with torch.no_grad(): # Batch processing batch_src = get_pixtral_vision_features(source, pixel_values) - batch_tgt, _ = target.get_image_features(pixel_values).view(-1, batch_src.shape[-1]) + batch_tgt = target.get_image_features(pixel_values)[0].view(-1, batch_src.shape[-1]) # Sequential processing singles_src = [get_pixtral_vision_features(source, pixel_values[i : i + 1]) for i in range(3)] From 68b68b2cd7ebc98d0c1e5ab04181d559cd7e5aff Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 25 Mar 2026 16:23:13 +0000 Subject: [PATCH 5/5] Fix intermittent test_data_streaming failure with fakeredis 2.34+ fakeredis 2.34 introduced Resp3Writer hardcoded for all TCP connections regardless of protocol negotiation. When XREADGROUP BLOCK times out on an empty stream, Resp3Writer.dump(None) sends RESP3 null (b'_\r\n'). The redis-py RESP2 parser (used by default) raises Protocol Error: b'_'. Fix: monkey-patch TCPFakeRequestHandler.setup in fake_redis_server() to replace Resp3Writer with Resp2Writer, restoring correct RESP2 null encoding (b'*-1\r\n') for blocking timeouts. The patch is guarded on the presence of Resp3Writer (2.34+ only) and raises explicitly if Resp2Writer is missing so future breakage is immediately diagnosable. --- tests/utils/redis.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/utils/redis.py b/tests/utils/redis.py index 2dc09bee2..6004425dc 100644 --- a/tests/utils/redis.py +++ b/tests/utils/redis.py @@ -81,6 +81,34 @@ def safe_handle(self): fakeredis._tcp_server.TCPFakeRequestHandler.handle = safe_handle + # ----- Monkey-patch setup to use Resp2Writer instead of Resp3Writer ----- + # fakeredis 2.34+ hardcodes Resp3Writer for all connections, causing blocked + # XREADGROUP timeouts to return RESP3 null (b'_\r\n') even on RESP2 connections + # (i.e. clients that never sent HELLO 3). The redis-py RESP2 parser raises + # Protocol Error: b'_' on this byte. Fix: replace with Resp2Writer at setup time. + # The Resp2Writer class was introduced alongside the bug in 2.34, so use its + # presence as the version guard. + orig_setup = fakeredis._tcp_server.TCPFakeRequestHandler.setup + if hasattr(fakeredis._tcp_server, "Resp3Writer"): + # fakeredis 2.34+ hardcodes Resp3Writer for all connections, causing blocked + # XREADGROUP timeouts to return RESP3 null (b'_\r\n') even on RESP2 connections + # (i.e. clients that never sent HELLO 3). The redis-py RESP2 parser raises + # Protocol Error: b'_' on this byte. Fix: replace with Resp2Writer at setup time. + if not hasattr(fakeredis._tcp_server, "Resp2Writer"): + raise RuntimeError( + f"fakeredis {fakeredis.__version__} has Resp3Writer but not Resp2Writer — " + "the workaround for the RESP2/RESP3 null encoding bug no longer applies. " + "See tests/utils/redis.py for details." + ) + + def resp2_setup(self): + orig_setup(self) + if not isinstance(self.writer, fakeredis._tcp_server.Resp2Writer): + self.writer = fakeredis._tcp_server.Resp2Writer(self.client_address, self.wfile, self) + self.current_client.writer = self.writer + + fakeredis._tcp_server.TCPFakeRequestHandler.setup = resp2_setup + server = fakeredis.TcpFakeServer((config.host, config.port), server_type="redis") print(f"Starting fake redis server at {config.host}:{config.port}") thread = threading.Thread(target=server.serve_forever, daemon=True) @@ -94,3 +122,5 @@ def safe_handle(self): server.shutdown() server.server_close() thread.join() + fakeredis._tcp_server.TCPFakeRequestHandler.setup = orig_setup + fakeredis._tcp_server.TCPFakeRequestHandler.handle = orig_handle