aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorThomas Letan <lthms@soap.coffee>2022-08-27 13:34:43 +0200
committerThomas Letan <lthms@soap.coffee>2022-08-27 13:34:43 +0200
commit2bc1e193c1a59e78bcd51e93fb89a9336cd63698 (patch)
tree9a8c60ac584036d4643eccf66905393173775d0f
parentClose Clap (diff)
Make 'Socket' more exception aware
-rw-r--r--lib/mltp_ipc/socket.ml47
-rw-r--r--lib/mltp_ipc/socket.mli36
-rw-r--r--lib/sway_ipc/sway_ipc.ml41
-rw-r--r--lib/sway_ipc/sway_ipc.mli2
4 files changed, 93 insertions, 33 deletions
diff --git a/lib/mltp_ipc/socket.ml b/lib/mltp_ipc/socket.ml
index b846804..e85b7b2 100644
--- a/lib/mltp_ipc/socket.ml
+++ b/lib/mltp_ipc/socket.ml
@@ -3,39 +3,56 @@
* file, You can obtain one at https://mozilla.org/MPL/2.0/. *)
type socket = Lwt_io.input_channel * Lwt_io.output_channel * Lwt_unix.file_descr
+type error = Bad_magic_string of string | Connection_closed
-let rec read_all ~count ((socket, _, _) as s) =
+let connect socket_path : socket Lwt.t =
let open Lwt.Syntax in
- let* payload = Lwt_io.read ~count socket in
- if String.length payload = count then Lwt.return payload
+ let socket = Lwt_unix.socket PF_UNIX SOCK_STREAM 0 in
+ let+ () = Lwt_unix.connect socket (ADDR_UNIX socket_path) in
+ let socket_in = Lwt_io.of_fd ~mode:Input socket in
+ let socket_out = Lwt_io.of_fd ~mode:Output socket in
+ (socket_in, socket_out, socket)
+
+let ( let*! ) x k = Lwt.bind x k
+let close (_, _, s) = Lwt_unix.close s
+
+let catch_end_of_file f =
+ Lwt.try_bind f Lwt_result.return @@ function
+ | End_of_file -> Lwt_result.fail Connection_closed
+ | exn -> Lwt.fail exn
+
+let rec read_all ~count ((socket, _, _) as s) =
+ let open Lwt_result.Syntax in
+ let* payload = catch_end_of_file (fun () -> Lwt_io.read ~count socket) in
+ if String.length payload = count then Lwt_result.return payload
else
let+ rest = read_all ~count:(count - String.length payload) s in
payload ^ rest
let read_magic_string ~magic_string socket =
- let open Lwt.Syntax in
- let magic = magic_string in
- let* msg = read_all ~count:(String.length magic) socket in
- assert (msg = magic);
- Lwt.return ()
+ let open Lwt_result.Syntax in
+ let* msg = read_all ~count:(String.length magic_string) socket in
+ if msg <> magic_string then
+ let*! () = close socket in
+ Lwt_result.fail (Bad_magic_string msg)
+ else Lwt_result.return ()
let write_raw_message ~magic_string (_, socket, _) raw =
let msg = Raw_message.to_string ~magic_string raw in
- Lwt_io.write socket msg
+ catch_end_of_file @@ fun () -> Lwt_io.write socket msg
let read_raw_message ~magic_string socket =
- let open Lwt.Syntax in
+ let open Lwt_result.Syntax in
let* () = read_magic_string ~magic_string socket in
let* msg = read_all ~count:4 socket in
let size = Raw_message.string_to_int32 msg in
let* msg = read_all ~count:4 socket in
let msg_type = Raw_message.string_to_int32 msg in
let* payload = read_all ~count:(Int32.to_int size) socket in
- Lwt.return (msg_type, payload)
+ Lwt_result.return (msg_type, payload)
let rec read_next_raw_message ~magic_string socket f =
- let open Lwt.Syntax in
+ let open Lwt_result.Syntax in
let* raw = read_raw_message ~magic_string socket in
- if f raw then Lwt.return raw else read_next_raw_message ~magic_string socket f
-
-let close (_, _, s) = Lwt_unix.close s
+ if f raw then Lwt_result.return raw
+ else read_next_raw_message ~magic_string socket f
diff --git a/lib/mltp_ipc/socket.mli b/lib/mltp_ipc/socket.mli
index 038473c..46dda86 100644
--- a/lib/mltp_ipc/socket.mli
+++ b/lib/mltp_ipc/socket.mli
@@ -13,17 +13,45 @@
type socket
(** A socket to communicate with a peer using the so-called MTLP protocol. *)
-val read_raw_message : magic_string:string -> socket -> Raw_message.t Lwt.t
+val connect : string -> socket Lwt.t
+(** Establish a bi-directional connection with a peer. *)
+
+val close : socket -> unit Lwt.t
+(** Close a bi-directional connection with a peer. *)
+
+type error =
+ | Bad_magic_string of string
+ (** When trying to read a MTLP message, the magic string was not
+ correct. *)
+ | Connection_closed
+ (** When trying to receive from or send a message to a closed
+ bi-directional connection. *)
+
+val read_raw_message :
+ magic_string:string -> socket -> (Raw_message.t, error) result Lwt.t
+(** [read_raw_message ~magic_string socket] reads a MTLP
+ message from [socket].
+
+ This function may fail with the following errors:
+
+ {ul {li [Bad_magic_string] (closes [socket] when it happens)}
+ {li [Connection_closed]}} *)
val read_next_raw_message :
magic_string:string ->
socket ->
(Raw_message.t -> bool) ->
- Raw_message.t Lwt.t
+ (Raw_message.t, error) result Lwt.t
(** [read_next_raw_message ~magic_string socket f] returns the next
raw message received by [socket] which satisfies [f]’s
conditions. Messages that don’t satisfy [f]’s conditions are
- ignored. *)
+ ignored.
+
+ This function may fail with the following errors:
+
+ {ul {li [Bad_magic_string] (closes [socket] when it happens)}
+ {li [Connection_closed]}} *)
val write_raw_message :
- magic_string:string -> socket -> Raw_message.t -> unit Lwt.t
+ magic_string:string -> socket -> Raw_message.t -> (unit, error) result Lwt.t
+(** This function may fail with [Connection_closed]. *)
diff --git a/lib/sway_ipc/sway_ipc.ml b/lib/sway_ipc/sway_ipc.ml
index 72c0798..be33e9f 100644
--- a/lib/sway_ipc/sway_ipc.ml
+++ b/lib/sway_ipc/sway_ipc.ml
@@ -5,6 +5,8 @@
open Sway_ipc_types
open Mltp_ipc
+exception Sway_ipc_error of Socket.error
+
let magic_string = "i3-ipc"
let sway_sock_path () =
@@ -14,22 +16,26 @@ let sway_sock_path () =
type socket = Socket.socket
-let connect () : socket Lwt.t =
- let open Lwt.Syntax in
- let socket = Lwt_unix.socket PF_UNIX SOCK_STREAM 0 in
- let+ () = Lwt_unix.connect socket (ADDR_UNIX (sway_sock_path ())) in
- let socket_in = Lwt_io.of_fd ~mode:Input socket in
- let socket_out = Lwt_io.of_fd ~mode:Output socket in
- (socket_in, socket_out, socket)
-
+let connect () : socket Lwt.t = Socket.connect (sway_sock_path ())
let close socket = Socket.close socket
+let trust_sway f =
+ let open Lwt.Syntax in
+ let* x = f () in
+ match x with Ok x -> Lwt.return x | Error e -> raise (Sway_ipc_error e)
+
let with_socket f =
let open Lwt.Syntax in
let* socket = connect () in
- let* res = f socket in
- let+ () = Socket.close socket in
- res
+ Lwt.try_bind
+ (fun () ->
+ let* res = f socket in
+ let* () = Socket.close socket in
+ Lwt.return res)
+ Lwt.return
+ (fun exn ->
+ let* () = Socket.close socket in
+ raise exn)
let socket_from_option = function
| Some socket -> Lwt.return socket
@@ -39,8 +45,12 @@ let send_command ?socket cmd =
let open Lwt.Syntax in
let* socket = socket_from_option socket in
let ((op, _) as raw) = Message.to_raw_message cmd in
- let* () = Socket.write_raw_message ~magic_string socket raw in
- let* op', payload = Socket.read_raw_message ~magic_string socket in
+ let* () =
+ trust_sway @@ fun () -> Socket.write_raw_message ~magic_string socket raw
+ in
+ let* op', payload =
+ trust_sway @@ fun () -> Socket.read_raw_message ~magic_string socket
+ in
assert (op = op');
Lwt.return @@ Json_decoder.of_string_exn (Message.reply_decoder cmd) payload
@@ -50,13 +60,16 @@ let subscribe ?socket events =
let+ { success } = send_command ~socket (Subscribe events) in
if success then
Lwt_stream.from (fun () ->
+ let open Lwt.Syntax in
let+ ev =
Socket.read_next_raw_message ~magic_string socket (fun (code, _) ->
List.exists
(fun ev_type -> ev_type = Event.event_type_of_code code)
events)
in
- Some (Event.event_of_raw_message ev))
+ match ev with
+ | Ok ev -> Some (Event.event_of_raw_message ev)
+ | Error _ -> None)
else failwith "Something went wrong"
let get_tree ?socket () = send_command ?socket Get_tree
diff --git a/lib/sway_ipc/sway_ipc.mli b/lib/sway_ipc/sway_ipc.mli
index f35d05b..d5d7e72 100644
--- a/lib/sway_ipc/sway_ipc.mli
+++ b/lib/sway_ipc/sway_ipc.mli
@@ -5,6 +5,8 @@
type socket
(** A socket to interact with Sway. *)
+exception Sway_ipc_error of Mltp_ipc.Socket.error
+
val connect : unit -> socket Lwt.t
(** [connect ()] establishes a connection with Sway. This connection
can be ended by using {!close}.