diff options
author | Thomas Letan <lthms@soap.coffee> | 2022-08-27 13:34:43 +0200 |
---|---|---|
committer | Thomas Letan <lthms@soap.coffee> | 2022-08-27 13:34:43 +0200 |
commit | 2bc1e193c1a59e78bcd51e93fb89a9336cd63698 (patch) | |
tree | 9a8c60ac584036d4643eccf66905393173775d0f | |
parent | Close Clap (diff) |
Make 'Socket' more exception aware
-rw-r--r-- | lib/mltp_ipc/socket.ml | 47 | ||||
-rw-r--r-- | lib/mltp_ipc/socket.mli | 36 | ||||
-rw-r--r-- | lib/sway_ipc/sway_ipc.ml | 41 | ||||
-rw-r--r-- | lib/sway_ipc/sway_ipc.mli | 2 |
4 files changed, 93 insertions, 33 deletions
diff --git a/lib/mltp_ipc/socket.ml b/lib/mltp_ipc/socket.ml index b846804..e85b7b2 100644 --- a/lib/mltp_ipc/socket.ml +++ b/lib/mltp_ipc/socket.ml @@ -3,39 +3,56 @@ * file, You can obtain one at https://mozilla.org/MPL/2.0/. *) type socket = Lwt_io.input_channel * Lwt_io.output_channel * Lwt_unix.file_descr +type error = Bad_magic_string of string | Connection_closed -let rec read_all ~count ((socket, _, _) as s) = +let connect socket_path : socket Lwt.t = let open Lwt.Syntax in - let* payload = Lwt_io.read ~count socket in - if String.length payload = count then Lwt.return payload + let socket = Lwt_unix.socket PF_UNIX SOCK_STREAM 0 in + let+ () = Lwt_unix.connect socket (ADDR_UNIX socket_path) in + let socket_in = Lwt_io.of_fd ~mode:Input socket in + let socket_out = Lwt_io.of_fd ~mode:Output socket in + (socket_in, socket_out, socket) + +let ( let*! ) x k = Lwt.bind x k +let close (_, _, s) = Lwt_unix.close s + +let catch_end_of_file f = + Lwt.try_bind f Lwt_result.return @@ function + | End_of_file -> Lwt_result.fail Connection_closed + | exn -> Lwt.fail exn + +let rec read_all ~count ((socket, _, _) as s) = + let open Lwt_result.Syntax in + let* payload = catch_end_of_file (fun () -> Lwt_io.read ~count socket) in + if String.length payload = count then Lwt_result.return payload else let+ rest = read_all ~count:(count - String.length payload) s in payload ^ rest let read_magic_string ~magic_string socket = - let open Lwt.Syntax in - let magic = magic_string in - let* msg = read_all ~count:(String.length magic) socket in - assert (msg = magic); - Lwt.return () + let open Lwt_result.Syntax in + let* msg = read_all ~count:(String.length magic_string) socket in + if msg <> magic_string then + let*! () = close socket in + Lwt_result.fail (Bad_magic_string msg) + else Lwt_result.return () let write_raw_message ~magic_string (_, socket, _) raw = let msg = Raw_message.to_string ~magic_string raw in - Lwt_io.write socket msg + catch_end_of_file @@ fun () -> Lwt_io.write socket msg let read_raw_message ~magic_string socket = - let open Lwt.Syntax in + let open Lwt_result.Syntax in let* () = read_magic_string ~magic_string socket in let* msg = read_all ~count:4 socket in let size = Raw_message.string_to_int32 msg in let* msg = read_all ~count:4 socket in let msg_type = Raw_message.string_to_int32 msg in let* payload = read_all ~count:(Int32.to_int size) socket in - Lwt.return (msg_type, payload) + Lwt_result.return (msg_type, payload) let rec read_next_raw_message ~magic_string socket f = - let open Lwt.Syntax in + let open Lwt_result.Syntax in let* raw = read_raw_message ~magic_string socket in - if f raw then Lwt.return raw else read_next_raw_message ~magic_string socket f - -let close (_, _, s) = Lwt_unix.close s + if f raw then Lwt_result.return raw + else read_next_raw_message ~magic_string socket f diff --git a/lib/mltp_ipc/socket.mli b/lib/mltp_ipc/socket.mli index 038473c..46dda86 100644 --- a/lib/mltp_ipc/socket.mli +++ b/lib/mltp_ipc/socket.mli @@ -13,17 +13,45 @@ type socket (** A socket to communicate with a peer using the so-called MTLP protocol. *) -val read_raw_message : magic_string:string -> socket -> Raw_message.t Lwt.t +val connect : string -> socket Lwt.t +(** Establish a bi-directional connection with a peer. *) + +val close : socket -> unit Lwt.t +(** Close a bi-directional connection with a peer. *) + +type error = + | Bad_magic_string of string + (** When trying to read a MTLP message, the magic string was not + correct. *) + | Connection_closed + (** When trying to receive from or send a message to a closed + bi-directional connection. *) + +val read_raw_message : + magic_string:string -> socket -> (Raw_message.t, error) result Lwt.t +(** [read_raw_message ~magic_string socket] reads a MTLP + message from [socket]. + + This function may fail with the following errors: + + {ul {li [Bad_magic_string] (closes [socket] when it happens)} + {li [Connection_closed]}} *) val read_next_raw_message : magic_string:string -> socket -> (Raw_message.t -> bool) -> - Raw_message.t Lwt.t + (Raw_message.t, error) result Lwt.t (** [read_next_raw_message ~magic_string socket f] returns the next raw message received by [socket] which satisfies [f]’s conditions. Messages that don’t satisfy [f]’s conditions are - ignored. *) + ignored. + + This function may fail with the following errors: + + {ul {li [Bad_magic_string] (closes [socket] when it happens)} + {li [Connection_closed]}} *) val write_raw_message : - magic_string:string -> socket -> Raw_message.t -> unit Lwt.t + magic_string:string -> socket -> Raw_message.t -> (unit, error) result Lwt.t +(** This function may fail with [Connection_closed]. *) diff --git a/lib/sway_ipc/sway_ipc.ml b/lib/sway_ipc/sway_ipc.ml index 72c0798..be33e9f 100644 --- a/lib/sway_ipc/sway_ipc.ml +++ b/lib/sway_ipc/sway_ipc.ml @@ -5,6 +5,8 @@ open Sway_ipc_types open Mltp_ipc +exception Sway_ipc_error of Socket.error + let magic_string = "i3-ipc" let sway_sock_path () = @@ -14,22 +16,26 @@ let sway_sock_path () = type socket = Socket.socket -let connect () : socket Lwt.t = - let open Lwt.Syntax in - let socket = Lwt_unix.socket PF_UNIX SOCK_STREAM 0 in - let+ () = Lwt_unix.connect socket (ADDR_UNIX (sway_sock_path ())) in - let socket_in = Lwt_io.of_fd ~mode:Input socket in - let socket_out = Lwt_io.of_fd ~mode:Output socket in - (socket_in, socket_out, socket) - +let connect () : socket Lwt.t = Socket.connect (sway_sock_path ()) let close socket = Socket.close socket +let trust_sway f = + let open Lwt.Syntax in + let* x = f () in + match x with Ok x -> Lwt.return x | Error e -> raise (Sway_ipc_error e) + let with_socket f = let open Lwt.Syntax in let* socket = connect () in - let* res = f socket in - let+ () = Socket.close socket in - res + Lwt.try_bind + (fun () -> + let* res = f socket in + let* () = Socket.close socket in + Lwt.return res) + Lwt.return + (fun exn -> + let* () = Socket.close socket in + raise exn) let socket_from_option = function | Some socket -> Lwt.return socket @@ -39,8 +45,12 @@ let send_command ?socket cmd = let open Lwt.Syntax in let* socket = socket_from_option socket in let ((op, _) as raw) = Message.to_raw_message cmd in - let* () = Socket.write_raw_message ~magic_string socket raw in - let* op', payload = Socket.read_raw_message ~magic_string socket in + let* () = + trust_sway @@ fun () -> Socket.write_raw_message ~magic_string socket raw + in + let* op', payload = + trust_sway @@ fun () -> Socket.read_raw_message ~magic_string socket + in assert (op = op'); Lwt.return @@ Json_decoder.of_string_exn (Message.reply_decoder cmd) payload @@ -50,13 +60,16 @@ let subscribe ?socket events = let+ { success } = send_command ~socket (Subscribe events) in if success then Lwt_stream.from (fun () -> + let open Lwt.Syntax in let+ ev = Socket.read_next_raw_message ~magic_string socket (fun (code, _) -> List.exists (fun ev_type -> ev_type = Event.event_type_of_code code) events) in - Some (Event.event_of_raw_message ev)) + match ev with + | Ok ev -> Some (Event.event_of_raw_message ev) + | Error _ -> None) else failwith "Something went wrong" let get_tree ?socket () = send_command ?socket Get_tree diff --git a/lib/sway_ipc/sway_ipc.mli b/lib/sway_ipc/sway_ipc.mli index f35d05b..d5d7e72 100644 --- a/lib/sway_ipc/sway_ipc.mli +++ b/lib/sway_ipc/sway_ipc.mli @@ -5,6 +5,8 @@ type socket (** A socket to interact with Sway. *) +exception Sway_ipc_error of Mltp_ipc.Socket.error + val connect : unit -> socket Lwt.t (** [connect ()] establishes a connection with Sway. This connection can be ended by using {!close}. |