From e1cd86f6d696b90979cfd287b0797670a3e54462 Mon Sep 17 00:00:00 2001 From: Joel Wurtz Date: Wed, 13 May 2026 18:16:17 +0200 Subject: [PATCH] feat(mysql): support load data local infile --- sqlx-core/src/fs.rs | 19 ++++++++ sqlx-mysql/src/connection/executor.rs | 17 ++++++- sqlx-mysql/src/connection/stream.rs | 40 ++++++++++++++++- .../src/protocol/response/local_infile.rs | 37 ++++++++++++++++ sqlx-mysql/src/protocol/response/mod.rs | 2 + tests/mysql/fixtures/load_data_infile.txt | 2 + tests/mysql/mysql.rs | 44 ++++++++++++++++++- 7 files changed, 156 insertions(+), 5 deletions(-) create mode 100644 sqlx-mysql/src/protocol/response/local_infile.rs create mode 100644 tests/mysql/fixtures/load_data_infile.txt diff --git a/sqlx-core/src/fs.rs b/sqlx-core/src/fs.rs index 0993cbeec6..45ae7ad89c 100644 --- a/sqlx-core/src/fs.rs +++ b/sqlx-core/src/fs.rs @@ -94,3 +94,22 @@ impl ReadDir { } } } + +#[cfg(feature = "_rt-tokio")] +pub async fn open_file>(path: P) -> Result { + if rt::rt_tokio::available() { + return tokio::fs::File::open(path).await; + } + + rt::missing_rt(path); +} + +#[cfg(all(feature = "_rt-async-io", not(feature = "_rt-tokio")))] +pub async fn open_file>(path: P) -> Result { + async_fs::File::open(path).await +} + +#[cfg(all(not(feature = "_rt-async-io"), not(feature = "_rt-tokio")))] +pub async fn open_file>(path: P) -> Result { + rt::missing_rt(path) +} diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index ee59d03d0a..acb7f03a7b 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -5,7 +5,7 @@ use crate::executor::{Execute, Executor}; use crate::ext::ustr::UStr; use crate::io::MySqlBufExt; use crate::logger::QueryLogger; -use crate::protocol::response::Status; +use crate::protocol::response::{LocalInfilePacket, Status}; use crate::protocol::statement::{ BinaryRow, Execute as StatementExecute, Prepare, PrepareOk, StmtClose, }; @@ -22,7 +22,9 @@ use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::TryStreamExt; use sqlx_core::column::{ColumnOrigin, TableColumn}; +use sqlx_core::fs::open_file; use sqlx_core::sql_str::SqlStr; +use std::path::PathBuf; use std::{pin::pin, sync::Arc}; impl MySqlConnection { @@ -209,6 +211,19 @@ impl MySqlConnection { return Ok(()); } + if packet[0] == 0xfb { + // LocalInfileRequest + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html + let packet = packet.decode::()?; + let path = PathBuf::from(String::from_utf8_lossy(&packet.filename).into_owned()); + let file = open_file(&path).await.map_err(|_| err_protocol!("cannot open file {} for local infile request", path.display()))?; + + self.inner.stream.send_stream(file).await?; + + continue; + } + + // otherwise, this first packet is the start of the result-set metadata, *self.inner.stream.waiting.front_mut().unwrap() = Waiting::Row; diff --git a/sqlx-mysql/src/connection/stream.rs b/sqlx-mysql/src/connection/stream.rs index e6aa8b48c8..efc764b8a4 100644 --- a/sqlx-mysql/src/connection/stream.rs +++ b/sqlx-mysql/src/connection/stream.rs @@ -5,7 +5,7 @@ use bytes::{Buf, Bytes, BytesMut}; use crate::error::Error; use crate::io::MySqlBufExt; -use crate::io::{ProtocolDecode, ProtocolEncode}; +use crate::io::{AsyncRead, ProtocolDecode, ProtocolEncode}; use crate::net::{BufferedSocket, Socket}; use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status}; use crate::protocol::{Capabilities, Packet}; @@ -43,7 +43,8 @@ impl MySqlStream { | Capabilities::MULTI_RESULTS | Capabilities::PLUGIN_AUTH | Capabilities::PS_MULTI_RESULTS - | Capabilities::SSL; + | Capabilities::SSL + | Capabilities::LOCAL_FILES; if options.database.is_some() { capabilities |= Capabilities::CONNECT_WITH_DB; @@ -108,6 +109,41 @@ impl MySqlStream { Ok(()) } + /// Send data from a stream to the database server as MySQL packets + /// + /// This is used to send data for a LOCAL INFILE query + pub(crate) async fn send_stream( + &mut self, + mut source: impl AsyncRead + Unpin, + ) -> Result<(), Error> { + loop { + let buf = self.socket.write_buffer_mut(); + + // Write the CopyData format code and reserve space for the length + sequence_id + // This is safe even if empty, since we always need to send an empty packet at the end + buf.put_slice(b"\0\0\0\0"); + + let read = buf.read_from(&mut source).await?; + let read32 = i32::try_from(read) + .map_err(|_| err_protocol!("number of bytes read exceeds 2^31 - 1: {}", read))?; + + // rewrite header (len + sequenceid) + let mut header = read32.to_le_bytes(); + header[3] = self.sequence_id; + self.sequence_id = self.sequence_id.wrapping_add(1); + + buf.get_mut()[..4].copy_from_slice(&header); + + self.socket.flush().await?; + + if read32 == 0 { + break; + } + } + + Ok(()) + } + pub(crate) fn write_packet<'en, T>(&mut self, payload: T) -> Result<(), Error> where T: ProtocolEncode<'en, Capabilities>, diff --git a/sqlx-mysql/src/protocol/response/local_infile.rs b/sqlx-mysql/src/protocol/response/local_infile.rs new file mode 100644 index 0000000000..33d9d47e00 --- /dev/null +++ b/sqlx-mysql/src/protocol/response/local_infile.rs @@ -0,0 +1,37 @@ +use bytes::{Buf, Bytes}; +use sqlx_core::io::{BufExt, ProtocolDecode}; + +use crate::error::Error; + +/// Requests the client to send a file to the server, following a LOCAL INFILE statement +/// +/// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html +#[derive(Debug)] +pub struct LocalInfilePacket { + pub filename: Vec, +} + +impl ProtocolDecode<'_> for LocalInfilePacket { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let header = buf.get_u8(); + if header != 0xfb { + return Err(err_protocol!( + "expected 0xfb (LocalInfileRequest) but found 0x{:02x}", + header + )); + } + + let filename = buf.get_bytes(buf.len()).to_vec(); + + Ok(Self { filename }) + } +} + +#[test] +fn test_decode_localinfile_packet() { + const DATA: &[u8] = b"\xfb\x64\x75\x6d\x6d\x79"; + + let p = LocalInfilePacket::decode(DATA.into()).unwrap(); + + assert_eq!(p.filename, b"dummy"); +} diff --git a/sqlx-mysql/src/protocol/response/mod.rs b/sqlx-mysql/src/protocol/response/mod.rs index 79767dc602..8d14f993e7 100644 --- a/sqlx-mysql/src/protocol/response/mod.rs +++ b/sqlx-mysql/src/protocol/response/mod.rs @@ -5,10 +5,12 @@ mod eof; mod err; +mod local_infile; mod ok; mod status; pub use eof::EofPacket; pub use err::ErrPacket; +pub use local_infile::LocalInfilePacket; pub use ok::OkPacket; pub use status::Status; diff --git a/tests/mysql/fixtures/load_data_infile.txt b/tests/mysql/fixtures/load_data_infile.txt new file mode 100644 index 0000000000..cdb2174598 --- /dev/null +++ b/tests/mysql/fixtures/load_data_infile.txt @@ -0,0 +1,2 @@ +1,a +2,b diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index 5374e651c8..3f1a4de41b 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -1,7 +1,7 @@ use anyhow::Context; use futures_util::TryStreamExt; use sqlx::mysql::{MySql, MySqlConnection, MySqlPool, MySqlPoolOptions, MySqlRow}; -use sqlx::{Column, Connection, Executor, Row, SqlSafeStr, Statement, TypeInfo}; +use sqlx::{AssertSqlSafe, Column, Connection, Executor, Row, SqlSafeStr, Statement, TypeInfo}; use sqlx_core::connection::ConnectOptions; use sqlx_core::types::Type; use sqlx_mysql::MySqlConnectOptions; @@ -599,7 +599,7 @@ async fn select_statement_count(conn: &mut MySqlConnection) -> Result anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_can_load_a_file() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let _ = conn + .execute( + r#" +CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY, name TEXT); + "#, + ) + .await?; + + let _ = conn.execute("SET GLOBAL local_infile = 1;").await?; + + let file_path = env::current_dir() + .unwrap() + .join("tests/mysql/fixtures/load_data_infile.txt"); + + // Execute LOAD DATA LOCAL INFILE + let load_query = format!( + "LOAD DATA LOCAL INFILE '{}' INTO TABLE users FIELDS TERMINATED BY ',' LINES TERMINATED BY '\\n'", + file_path.display() + ); + + let result = conn.execute(AssertSqlSafe(load_query)).await; + + if let Err(e) = result { + assert!(false, "{:?}", e) + } + + let name = sqlx::query("SELECT name FROM users WHERE id = 1") + .try_map(|row: MySqlRow| row.try_get::(0)) + .fetch_one(&mut conn) + .await?; + + assert_eq!("a", name); + + Ok(()) +}