This commit is contained in:
Swrup 2025-02-12 11:19:19 +01:00
commit 7f276c45a0
12 changed files with 673 additions and 0 deletions

3
src/dune Normal file
View file

@ -0,0 +1,3 @@
(library
(public_name purr_chacha)
(name purr_chacha))

284
src/poly1305.ml Normal file
View file

@ -0,0 +1,284 @@
(* adapted from the 64 bit version of:
https://github.com/floodyberry/poly1305-donna *)
open Int64
type int128 =
{ hi : Int64.t
; lo : Int64.t
}
let mul128 a b =
let a_lo = logand a 0xffffffff_L in
let a_hi = shift_right_logical a 32 in
let b_lo = logand b 0xffffffff_L in
let b_hi = shift_right_logical b 32 in
let lolo = mul a_lo b_lo in
let lohi = mul a_lo b_hi in
let hilo = mul a_hi b_lo in
let hihi = mul a_hi b_hi in
let c =
add
(add (shift_right_logical lolo 32) (logand lohi 0xffffffff_L))
(logand hilo 0xffffffff_L)
in
let lo = logor (logand lolo 0xffffffff_L) (shift_left c 32) in
let hi =
add
(add
(add hihi (shift_right_logical lohi 32))
(shift_right_logical hilo 32) )
(shift_right_logical c 32)
in
{ hi; lo }
let lt a b = if unsigned_compare a b < 0 then 1_L else 0_L
let add128 a b =
let t = a.lo in
let lo = add a.lo b.lo in
let hi = add (add b.hi (lt lo t)) a.hi in
{ hi; lo }
let addlo a b =
let t = a.lo in
let lo = add a.lo b in
let hi = add a.hi (lt lo t) in
{ hi; lo }
let shr a n = add (shift_right_logical a.lo n) (shift_left a.hi (64 - n))
type state =
{ r : Int64.t array
; h : Int64.t array
; pad : Int64.t array
; mutable leftover : int
; buffer : bytes
; mutable final : bool
; mac : bytes
}
let u8to64 s pos = String.get_int64_le s pos
let init key =
let t0 = u8to64 key 0 in
let t1 = u8to64 key 8 in
let r = Array.make 3 0x0_L in
r.(0) <- logand t0 0xffc0fffffff_L;
r.(1) <-
logand
(logor (shift_right_logical t0 44) (shift_left t1 20))
0xfffffc0ffff_L;
r.(2) <- logand (shift_right_logical t1 24) 0x00ffffffc0f_L;
let h = Array.make 3 0x0_L in
let pad = Array.make 2 0x0_L in
pad.(0) <- u8to64 key 16;
pad.(1) <- u8to64 key 24;
let leftover = 0 in
let buffer = Bytes.make 16 '\000' in
let final = true in
let mac = Bytes.make 16 '\000' in
{ r; h; pad; leftover; buffer; final; mac }
let blocks state s pos len =
let hibit : Int64.t = if state.final then shift_left 1_L 40 else 0_L in
let r = Array.copy state.r in
let h = Array.copy state.h in
let s1 = mul r.(1) (shift_left 5_L 2) in
let s2 = mul r.(2) (shift_left 5_L 2) in
let rec loop pos len =
if len < 16 then ()
else
(* h += m[i] *)
let t0 = u8to64 s (pos + 0) in
let t1 = u8to64 s (pos + 8) in
h.(0) <- add h.(0) (logand t0 0xfffffffffff_L);
h.(1) <-
add h.(1)
(logand
(logor (shift_right_logical t0 44) (shift_left t1 20))
0xfffffffffff_L );
h.(2) <-
add h.(2)
(logor (logand (shift_right_logical t1 24) 0x3ffffffffff_L) hibit);
(* h *= r *)
let d0 : int128 =
add128 (add128 (mul128 h.(0) r.(0)) (mul128 h.(1) s2)) (mul128 h.(2) s1)
in
let d1 =
add128
(add128 (mul128 h.(0) r.(1)) (mul128 h.(1) r.(0)))
(mul128 h.(2) s2)
in
let d2 =
add128
(add128 (mul128 h.(0) r.(2)) (mul128 h.(1) r.(1)))
(mul128 h.(2) r.(0))
in
(* (partial) h %= p *)
let c = shr d0 44 in
h.(0) <- logand d0.lo 0xfffffffffff_L;
let d1 = addlo d1 c in
let c = shr d1 44 in
h.(1) <- logand d1.lo 0xfffffffffff_L;
let d2 = addlo d2 c in
let c = shr d2 42 in
h.(2) <- logand d2.lo 0x3ffffffffff_L;
h.(0) <- add h.(0) (mul c 5_L);
let c = shift_right_logical h.(0) 44 in
h.(0) <- logand h.(0) 0xfffffffffff_L;
h.(1) <- add h.(1) c;
loop (pos + 16) (len - 16)
in
loop pos len;
state.h.(0) <- h.(0);
state.h.(1) <- h.(1);
state.h.(2) <- h.(2);
()
let finish state =
(* process the remaining block *)
if state.leftover <> 0 then (
let i = state.leftover in
Bytes.set state.buffer i '\001';
for i = i + 1 to 16 - 1 do
Bytes.set state.buffer i '\000'
done;
state.final <- false;
blocks state (Bytes.unsafe_to_string state.buffer) 0 16 );
(* fully carry h *)
let h = Array.copy state.h in
let c = shift_right_logical h.(1) 44 in
h.(1) <- logand h.(1) 0xfffffffffff_L;
h.(2) <- add h.(2) c;
let c = shift_right_logical h.(2) 42 in
h.(2) <- logand h.(2) 0x3ffffffffff_L;
h.(0) <- add h.(0) (mul c 5_L);
let c = shift_right_logical h.(0) 44 in
h.(0) <- logand h.(0) 0xfffffffffff_L;
h.(1) <- add h.(1) c;
let c = shift_right_logical h.(1) 44 in
h.(1) <- logand h.(1) 0xfffffffffff_L;
h.(2) <- add h.(2) c;
let c = shift_right_logical h.(2) 42 in
h.(2) <- logand h.(2) 0x3ffffffffff_L;
h.(0) <- add h.(0) (mul c 5_L);
let c = shift_right_logical h.(0) 44 in
h.(0) <- logand h.(0) 0xfffffffffff_L;
h.(1) <- add h.(1) c;
(* compute h + -p *)
let g = Array.make 3 0_L in
g.(0) <- add h.(0) 5_L;
let c = shift_right_logical g.(0) 44 in
g.(0) <- logand g.(0) 0xfffffffffff_L;
g.(1) <- add h.(1) c;
let c = shift_right_logical g.(1) 44 in
g.(1) <- logand g.(1) 0xfffffffffff_L;
g.(2) <- sub (add h.(2) c) (shift_left 1_L 42);
(* select h if h < p, or h + -p if h >= p *)
let c = sub (shift_right_logical g.(2) (64 - 1)) 1_L in
g.(0) <- logand g.(0) c;
g.(1) <- logand g.(1) c;
g.(2) <- logand g.(2) c;
let c = lognot c in
h.(0) <- logor (logand h.(0) c) g.(0);
h.(1) <- logor (logand h.(1) c) g.(1);
h.(2) <- logor (logand h.(2) c) g.(2);
(* h = (h + pad) *)
let t0 = state.pad.(0) in
let t1 = state.pad.(1) in
h.(0) <- add h.(0) (logand t0 0xfffffffffff_L);
let c = shift_right_logical h.(0) 44 in
h.(0) <- logand h.(0) 0xfffffffffff_L;
h.(1) <-
add h.(1)
(add
(logand
(logor (shift_right_logical t0 44) (shift_left t1 20))
0xfffffffffff_L )
c );
let c = shift_right_logical h.(1) 44 in
h.(1) <- logand h.(1) 0xfffffffffff_L;
h.(2) <-
add h.(2) (add (logand (shift_right_logical t1 24) 0xfffffffffff_L) c);
h.(2) <- logand h.(2) 0x3ffffffffff_L;
(* mac = h % (2^128) *)
h.(0) <- logor h.(0) (shift_left h.(1) 44);
h.(1) <- logor (shift_right_logical h.(1) 20) (shift_left h.(2) 24);
Bytes.set_int64_le state.mac 0 h.(0);
Bytes.set_int64_le state.mac 8 h.(1);
(* zero out the state *)
state.h.(0) <- 0_L;
state.h.(1) <- 0_L;
state.h.(2) <- 0_L;
state.r.(0) <- 0_L;
state.r.(1) <- 0_L;
state.r.(2) <- 0_L;
state.pad.(0) <- 0_L;
state.pad.(1) <- 0_L;
()
let update state s =
let pos = ref 0 in
let len = ref (String.length s) in
(* handle leftover *)
let return =
if state.leftover <> 0 then (
let want = 16 - state.leftover in
let want = if want > !len then !len else want in
String.blit s !pos state.buffer state.leftover want;
len := !len - want;
pos := !pos + want;
state.leftover <- state.leftover + want;
if state.leftover < 16 then true
else (
blocks state (Bytes.unsafe_to_string state.buffer) 0 16;
state.leftover <- 0;
false ) )
else false
in
if return then ()
else (
(* process full blocks *)
if !len >= 16 then (
let want = Int.logand !len (Int.lognot (16 - 1)) in
blocks state s !pos want;
pos := !pos + want;
len := !len - want );
(* store leftover *)
if !len <> 0 then (
String.blit s !pos state.buffer state.leftover !len;
state.leftover <- state.leftover + !len;
() );
() )
let mac ~key s =
if String.length key <> 32 then invalid_arg "key length must be 32 bytes";
let state = init key in
update state s;
finish state;
Bytes.unsafe_to_string state.mac

135
src/purr_chacha.ml Normal file
View file

@ -0,0 +1,135 @@
(* n-bit left rotation (<<<)
result is unspecified if n < 0 or n >= 32 *)
let rot_l_32 v n =
let open Int32 in
logor (shift_left v n) (shift_right_logical v (32 - n))
(* mutates chacha state [s] *)
let quarter_round s a b c d =
let open Int32 in
(* a += b; d ^= a; d <<<= 16; *)
s.(a) <- add s.(a) s.(b);
s.(d) <- logxor s.(d) s.(a);
s.(d) <- rot_l_32 s.(d) 16;
(* c += d; b ^= c; b <<<= 12; *)
s.(c) <- add s.(c) s.(d);
s.(b) <- logxor s.(b) s.(c);
s.(b) <- rot_l_32 s.(b) 12;
(* a += b; d ^= a; d <<<= 8; *)
s.(a) <- add s.(a) s.(b);
s.(d) <- logxor s.(d) s.(a);
s.(d) <- rot_l_32 s.(d) 8;
(* c += d; b ^= c; b <<<= 7; *)
s.(c) <- add s.(c) s.(d);
s.(b) <- logxor s.(b) s.(c);
s.(b) <- rot_l_32 s.(b) 7;
()
let init_state key block_counter nonce =
Array.concat
[ [| 0x61707865_l; 0x3320646e_l; 0x79622d32_l; 0x6b206574_l |]
; key
; [| block_counter |]
; nonce
]
let chacha20_block key block_counter nonce =
let s = init_state key block_counter nonce in
let w_s = Array.copy s in
for _i = 0 to 10 - 1 do
(* round 1. *)
quarter_round w_s 0 4 8 12;
quarter_round w_s 1 5 9 13;
quarter_round w_s 2 6 10 14;
quarter_round w_s 3 7 11 15;
(* round 2. *)
quarter_round w_s 0 5 10 15;
quarter_round w_s 1 6 11 12;
quarter_round w_s 2 7 8 13;
quarter_round w_s 3 4 9 14
done;
for i = 0 to 15 do
s.(i) <- Int32.add s.(i) w_s.(i)
done;
s
let serialize state =
let len = Array.length state in
let s = Bytes.create (4 * len) in
for i = 0 to len - 1 do
Bytes.set_int32_le s (i * 4) state.(i)
done;
Bytes.unsafe_to_string s
(* TODO check Sys.big_endian ? this is probably bugged *)
(* ! raw length must be divisible by 4 here
use padding with \x00 if needed
convert string to int32 array; in correct order
we receive key and nonce as a sequence of octets with
no particular structure; to read in little endian *)
let read raw =
let nb = String.length raw / 4 in
let arr = Array.init nb (fun i -> String.get_int32_le raw (i * 4)) in
arr
(* [xor a b] apply xor to two int32 array; up to len = len(a)
put result in a *)
let xor a b =
for i = 0 to Array.length a - 1 do
a.(i) <- Int32.logxor a.(i) b.(i)
done
let chacha20_encrypt ~key ~nonce ?initial_counter plaintext =
if String.length key <> 32 then invalid_arg "key length must be 32 bytes";
if String.length nonce <> 12 then invalid_arg "nonce length must be 12 bytes";
let key = read key in
let nonce = read nonce in
let initial_counter = Option.value ~default:0_l initial_counter in
let len = String.length plaintext in
let remaining_len = len mod 64 in
let encrypted_message = Bytes.create ((len / 64 * 64) + remaining_len) in
for j = 0 to (len / 64) - 1 do
let block_counter = Int32.add initial_counter (Int32.of_int j) in
let key_stream = chacha20_block key block_counter nonce in
let block = read (String.sub plaintext (j * 64) 64) in
xor block key_stream;
let s = serialize block in
Bytes.blit_string s 0 encrypted_message (j * 64) 64;
()
done;
if remaining_len <> 0 then (
let j = String.length plaintext / 64 in
let block_counter = Int32.add initial_counter (Int32.of_int j) in
let key_stream = chacha20_block key block_counter nonce in
let remaining_padded_plaintext =
let padded =
Bytes.make (remaining_len - ((remaining_len mod 4) - 4)) '\x00'
in
let s = String.sub plaintext (j * 64) remaining_len in
Bytes.blit_string s 0 padded 0 remaining_len;
Bytes.unsafe_to_string padded
in
let block = read remaining_padded_plaintext in
xor block key_stream;
let s = serialize block in
(* rm padding *)
let s = String.sub s 0 remaining_len in
Bytes.blit_string s 0 encrypted_message (j * 64) remaining_len;
() );
Bytes.unsafe_to_string encrypted_message
let poly1305_mac = Poly1305.mac
let poly1305_key_gen ~key ~nonce =
if String.length key <> 32 then invalid_arg "key length must be 32 bytes";
if String.length nonce <> 12 then invalid_arg "nonce length must be 12 bytes";
let key = read key in
let nonce = read nonce in
let counter = 0_l in
let block = chacha20_block key counter nonce in
let s = Bytes.create 32 in
for i = 0 to 8 - 1 do
Bytes.set_int32_le s (i * 4) block.(i)
done;
Bytes.unsafe_to_string s

21
src/purr_chacha.mli Normal file
View file

@ -0,0 +1,21 @@
(* Purr_chacha, a pure OCaml implementation of ChaCha20 stream cipher
https://datatracker.ietf.org/doc/html/rfc7539 *)
(* [chacha20_encrypt ~key ~nonce ~initial_counter s]
Is ChaCha20 ciphertext of [s]
[key] must be 32 bytes
[nonce] must be 12 bytes
[initial_counter] default to Int32.zero *)
val chacha20_encrypt :
key:string -> nonce:string -> ?initial_counter:int32 -> string -> string
(* [poly1305_mac ~key s ]
Is the poly1305 message authentication code of [s] with [key]
[key] must be 32 bytes *)
val poly1305_mac : key:string -> string -> string
(* [poly1305_key_gen ~key ~nonce ]
Is a one-time poly1305 key generated pseudorandomly with chacha20
[key] must be 32 bytes
[nonce] must be 12 bytes *)
val poly1305_key_gen : key:string -> nonce:string -> string