diff --git a/lib/postgrex/protocol.ex b/lib/postgrex/protocol.ex index dfd94939..8258c77a 100644 --- a/lib/postgrex/protocol.ex +++ b/lib/postgrex/protocol.ex @@ -350,7 +350,7 @@ defmodule Postgrex.Protocol do | {:error, %ArgumentError{} | Postgrex.Error.t(), state} | {:error, %DBConnection.TransactionError{}, state} | {:disconnect, %RuntimeError{}, state} - | {:disconnect, %DBConnection.ConnectionError{}, state} + | {:disconnect | :disconnect_and_retry, %DBConnection.ConnectionError{}, state} def handle_prepare(%Query{} = query, _, %{postgres: {_, _}} = s) do lock_error(s, :prepare, query) end @@ -365,15 +365,18 @@ defmodule Postgrex.Protocol do def handle_prepare(%Query{name: ""} = query, opts, s) do prepare = Keyword.get(opts, :postgrex_prepare, false) status = new_status(opts, prepare: prepare) + comment = Keyword.get(opts, :comment) - case prepare do - true -> - parse_describe_close(s, status, query) + result = + case prepare do + true -> + parse_describe_close(s, status, query) - false -> - comment = Keyword.get(opts, :comment) - parse_describe_flush(s, status, query, comment) - end + false -> + parse_describe_flush(s, status, query, comment) + end + + handle_disconnect_retry(result) end def handle_prepare(%Query{} = query, opts, %{queries: nil} = s) do @@ -395,8 +398,9 @@ defmodule Postgrex.Protocol do false -> close_parse_describe_flush(s, status, query, comment) end - with {:ok, query, s} <- result do - {:ok, query, %{s | messages: []}} + case result do + {:ok, query, s} -> {:ok, query, %{s | messages: []}} + other -> handle_disconnect_retry(other) end end end @@ -422,11 +426,14 @@ defmodule Postgrex.Protocol do | {:error, %ArgumentError{} | Postgrex.Error.t(), state} | {:error, %DBConnection.TransactionError{}, state} | {:disconnect, %RuntimeError{}, state} - | {:disconnect, %DBConnection.ConnectionError{}, state} + | {:disconnect | :disconnect_and_retry, %DBConnection.ConnectionError{}, state} def handle_execute(%Query{} = query, params, opts, s) do case Keyword.get(opts, :postgrex_copy, false) do - true -> handle_execute_copy(query, params, opts, s) - false -> handle_execute_result(query, params, opts, s) + true -> + handle_execute_copy(query, params, opts, s) + + false -> + handle_execute_result(query, params, opts, s) end end @@ -503,9 +510,10 @@ defmodule Postgrex.Protocol do {:ok, Postgrex.Result.t(), state} | {:error, %ArgumentError{} | Postgrex.Error.t(), state} | {:disconnect, %RuntimeError{}, state} - | {:disconnect, %DBConnection.ConnectionError{}, state} + | {:disconnect | :disconnect_and_retry, %DBConnection.ConnectionError{}, state} def handle_close(%Query{ref: ref} = query, opts, %{postgres: {_, ref}} = s) do - flushed_close(s, new_status(opts), query) + result = flushed_close(s, new_status(opts), query) + handle_disconnect_retry(result) end def handle_close(%Query{} = query, _, %{postgres: {_, _}} = s) do @@ -513,7 +521,8 @@ defmodule Postgrex.Protocol do end def handle_close(%Query{} = query, opts, s) do - close(s, new_status(opts), query) + result = close(s, new_status(opts), query) + handle_disconnect_retry(result) end @impl true @@ -582,7 +591,8 @@ defmodule Postgrex.Protocol do {:ok, Postgrex.Result.t(), state} | {DBConnection.status(), state} | {:disconnect, %RuntimeError{}, state} - | {:disconnect, %DBConnection.ConnectionError{} | Postgrex.Error.t(), state} + | {:disconnect | :disconnect_and_retry, + %DBConnection.ConnectionError{} | Postgrex.Error.t(), state} def handle_begin(_, %{postgres: {_, _}} = s) do lock_error(s, :begin) end @@ -591,7 +601,8 @@ defmodule Postgrex.Protocol do case Keyword.get(opts, :mode, :transaction) do :transaction when postgres == :idle -> statement = "BEGIN" - handle_transaction(statement, opts, s) + result = handle_transaction(statement, opts, s) + handle_disconnect_retry(result) :savepoint when postgres == :transaction -> statement = "SAVEPOINT postgrex_savepoint" @@ -2081,7 +2092,7 @@ defmodule Postgrex.Protocol do bind_execute_close(s, status, query, params) {error, _, _} = other when error in [:error, :disconnect] -> - other + handle_disconnect_retry(other) end end @@ -2093,7 +2104,7 @@ defmodule Postgrex.Protocol do bind_execute(s, status, query, params) {error, _, _} = other when error in [:error, :disconnect] -> - other + handle_disconnect_retry(other) end end @@ -2114,8 +2125,8 @@ defmodule Postgrex.Protocol do msg_sync() ] - with :ok <- msg_send(%{s | buffer: nil}, msgs, buffer), - {:ok, s, buffer} <- recv_bind(s, status, buffer), + with :ok <- msg_send(%{s | buffer: nil}, msgs, buffer) |> handle_disconnect_retry(), + {:ok, s, buffer} <- recv_bind(s, status, buffer) |> handle_disconnect_retry(), {:ok, result, s, buffer} <- recv_execute(s, status, query, buffer), {:ok, s, buffer} <- recv_close(s, status, buffer), {:ok, s} <- recv_ready(s, status, buffer) do @@ -2125,7 +2136,7 @@ defmodule Postgrex.Protocol do error_ready(s, status, err, buffer) |> maybe_disconnect() - {:disconnect, _err, _s} = disconnect -> + {_disconnect_or_retry, _err, _s} = disconnect -> disconnect end end @@ -2151,8 +2162,8 @@ defmodule Postgrex.Protocol do msg_sync() ] - with :ok <- msg_send(%{s | buffer: nil}, msgs, buffer), - {:ok, s, buffer} <- recv_bind(s, status, buffer), + with :ok <- msg_send(%{s | buffer: nil}, msgs, buffer) |> handle_disconnect_retry(), + {:ok, s, buffer} <- recv_bind(s, status, buffer) |> handle_disconnect_retry(), {:ok, result, s, buffer} <- recv_execute(s, status, query, buffer), {:ok, s} <- recv_ready(s, status, buffer) do {:ok, query, result, s} @@ -2163,7 +2174,7 @@ defmodule Postgrex.Protocol do error_ready(s, status, err, buffer) |> maybe_disconnect() - {:disconnect, _err, _s} = disconnect -> + {_disconnect_or_retry, _err, _s} = disconnect -> disconnect end end @@ -3391,7 +3402,13 @@ defmodule Postgrex.Protocol do end defp conn_error(mod, action, reason) when reason in @nonposix_errors do - conn_error("#{mod} #{action}: #{reason}") + msg = "#{mod} #{action}: #{reason}" + + if reason == :closed do + conn_error(msg, :closed) + else + conn_error(msg) + end end defp conn_error(:tcp, action, reason) do @@ -3404,6 +3421,10 @@ defmodule Postgrex.Protocol do conn_error("ssl #{action}: #{formatted_reason} - #{inspect(reason)}") end + defp conn_error(message, reason) do + DBConnection.ConnectionError.exception(message: message, reason: reason) + end + defp conn_error(message) do DBConnection.ConnectionError.exception(message) end @@ -3416,6 +3437,16 @@ defmodule Postgrex.Protocol do {:disconnect, err, %{s | buffer: buffer}} end + # This function is used in two ways: + # + # * When we know the operation is fully retriable, we invoke it at the top + # * When only part is retriable (such as bind in execute or begin in a transaction), + # we invoke it at the specific instructions + defp handle_disconnect_retry({:disconnect, %{reason: :closed} = err, s}), + do: {:disconnect_and_retry, err, s} + + defp handle_disconnect_retry(other), do: other + defp sync_recv(s, status, buffer) do %{postgres: postgres, transactions: transactions} = s diff --git a/mix.exs b/mix.exs index 284be8c1..4339b096 100644 --- a/mix.exs +++ b/mix.exs @@ -33,7 +33,7 @@ defmodule Postgrex.Mixfile do {:jason, "~> 1.0", optional: true}, {:table, "~> 0.1.0", optional: true}, {:decimal, "~> 1.5 or ~> 2.0"}, - {:db_connection, "~> 2.1"} + {:db_connection, github: "elixir-ecto/db_connection", branch: "master"} ] end diff --git a/mix.lock b/mix.lock index f1e4a656..a228a8e2 100644 --- a/mix.lock +++ b/mix.lock @@ -1,5 +1,5 @@ %{ - "db_connection": {:hex, :db_connection, "2.7.0", "b99faa9291bb09892c7da373bb82cba59aefa9b36300f6145c5f201c7adf48ec", [:mix], [{:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "dcf08f31b2701f857dfc787fbad78223d61a32204f217f15e881dd93e4bdd3ff"}, + "db_connection": {:git, "https://github.com/elixir-ecto/db_connection.git", "ce227e06605b77540c6b27bc94acddaf4e7ae027", [branch: "master"]}, "decimal": {:hex, :decimal, "1.9.0", "83e8daf59631d632b171faabafb4a9f4242c514b0a06ba3df493951c08f64d07", [:mix], [], "hexpm", "b1f2343568eed6928f3e751cf2dffde95bfaa19dd95d09e8a9ea92ccfd6f7d85"}, "earmark_parser": {:hex, :earmark_parser, "1.4.44", "f20830dd6b5c77afe2b063777ddbbff09f9759396500cdbe7523efd58d7a339c", [:mix], [], "hexpm", "4778ac752b4701a5599215f7030989c989ffdc4f6df457c5f36938cc2d2a2750"}, "ex_doc": {:hex, :ex_doc, "0.38.2", "504d25eef296b4dec3b8e33e810bc8b5344d565998cd83914ffe1b8503737c02", [:mix], [{:earmark_parser, "~> 1.4.44", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "732f2d972e42c116a70802f9898c51b54916e542cc50968ac6980512ec90f42b"}, @@ -9,5 +9,5 @@ "makeup_erlang": {:hex, :makeup_erlang, "1.0.2", "03e1804074b3aa64d5fad7aa64601ed0fb395337b982d9bcf04029d68d51b6a7", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "af33ff7ef368d5893e4a267933e7744e46ce3cf1f61e2dccf53a111ed3aa3727"}, "nimble_parsec": {:hex, :nimble_parsec, "1.4.2", "8efba0122db06df95bfaa78f791344a89352ba04baedd3849593bfce4d0dc1c6", [:mix], [], "hexpm", "4b21398942dda052b403bbe1da991ccd03a053668d147d53fb8c4e0efe09c973"}, "table": {:hex, :table, "0.1.0", "f16104d717f960a623afb134a91339d40d8e11e0c96cfce54fee086b333e43f0", [:mix], [], "hexpm", "bf533d3606823ad8a7ee16f41941e5e6e0e42a20c4504cdf4cfabaaed1c8acb9"}, - "telemetry": {:hex, :telemetry, "1.0.0", "0f453a102cdf13d506b7c0ab158324c337c41f1cc7548f0bc0e130bbf0ae9452", [:rebar3], [], "hexpm", "73bc09fa59b4a0284efb4624335583c528e07ec9ae76aca96ea0673850aec57a"}, + "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, } diff --git a/test/query_test.exs b/test/query_test.exs index 5730204d..7e7f5ef9 100644 --- a/test/query_test.exs +++ b/test/query_test.exs @@ -1965,4 +1965,70 @@ defmodule QueryTest do Postgrex.execute!(context[:pid], "name", "postgrex") end end + + test "disconnect_and_retry with prepare" do + # Start new connection so we can retry on disconnect + opts = [database: "postgrex_test", backoff_min: 1, backoff_max: 1] + {:ok, pid} = P.start_link(opts) + + # Drop socket + disconnect(pid) + + # Assert preparation happens instead of returning error + assert {:ok, _} = P.prepare(pid, "42", "SELECT 42") + end + + test "disconnect_and_retry with transaction" do + # Start new connection so we can retry on disconnect + opts = [database: "postgrex_test", backoff_min: 1, backoff_max: 1] + {:ok, pid} = P.start_link(opts) + + # Drop socket + disconnect(pid) + + # Assert transaction happens instead of returning error + assert {:ok, _} = P.transaction(pid, fn conn -> P.query(conn, "SELECT 1", []) end) + end + + test "disconnect_and_retry with closing prepared statement" do + # Start new connection so we can retry on disconnect + opts = [database: "postgrex_test", backoff_min: 1, backoff_max: 1] + {:ok, pid} = P.start_link(opts) + + # Prepare query that we wil try to close after disconnecting + {:ok, query} = P.prepare(pid, "42", "SELECT 42") + + # Drop socket + disconnect(pid) + + # Assert close happens instead of returning error + assert :ok = P.close(pid, query) + end + + test "disconnect_and_retry on attempting execution of prepared statement" do + # Start new connection so we can retry on disconnect + opts = [database: "postgrex_test", backoff_min: 1, backoff_max: 1] + {:ok, pid} = P.start_link(opts) + + # Prepare query that we wil try to execute after disconnecting + {:ok, query} = P.prepare(pid, "42", "SELECT 42") + + # Drop socket + disconnect(pid) + + # Assert execute happens instead of returning error + assert {:ok, _, _} = P.execute(pid, query, []) + end + + defp disconnect(pid) do + sock = DBConnection.run(pid, &get_socket/1) + :gen_tcp.shutdown(sock, :read_write) + end + + defp get_socket(conn) do + {:pool_ref, _, _, _, holder, _} = conn.pool_ref + [{:conn, _, _, state, _, _, _, _}] = :ets.lookup(holder, :conn) + {:gen_tcp, sock} = state.sock + sock + end end