{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, FlexibleContexts #-}

module Crypto.Nettle.Ciphers.Internal
	( NettleCipher(..)
	, NettleBlockCipher(..)
	, NettleStreamCipher(..)
	, NettleBlockedStreamCipher(..)
	, NettleGCM
	, nettle_cipherInit
	, nettle_cipherInit'
	, nettle_ecbEncrypt
	, nettle_ecbDecrypt
	, nettle_cbcEncrypt
	, nettle_cbcDecrypt
	, nettle_cfbEncrypt
	, nettle_cfbDecrypt
	, nettle_ctrCombine
	, nettle_streamCombine
	, nettle_streamSetNonce
	, nettle_blockedStreamCombine
	, nettle_blockedStreamSetNonce
	, nettle_gcm_aeadInit
	, nettle_gcm_aeadStateAppendHeader
	, nettle_gcm_aeadStateEncrypt
	, nettle_gcm_aeadStateDecrypt
	, nettle_gcm_aeadStateFinalize
	) where

import Crypto.Cipher.Types as T

import Data.Tagged
import Data.Byteable (Byteable(..))
import Data.SecureMem
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B
import Data.Bits (xor)

import Nettle.Utils
import Crypto.Nettle.Ciphers.ForeignImports

-- internal functions are not camelCase on purpose
{-# ANN module "HLint: ignore Use camelCase" #-}

-- | Interface shared by all cipher wrappers in this module: the cipher state
-- is an opaque context held in a 'SecureMem', and this class describes how
-- to initialize that context and how to wrap and unwrap it.
class NettleCipher c where
	-- | pointer to new context, key length, (const) key pointer
	nc_cipherInit    :: Tagged c (Ptr Word8 -> Word -> Ptr Word8 -> IO())
	-- | name of the cipher
	nc_cipherName    :: Tagged c String
	-- | key size specification of the cipher
	nc_cipherKeySize :: Tagged c T.KeySizeSpecifier
	-- | size of the context; @nettle_cipherInit'@ allocates a 'SecureMem' of
	-- this size before running the init function on it
	nc_ctx_size      :: Tagged c Int
	-- | extract the context from a cipher value
	nc_ctx           :: c -> SecureMem
	-- | build a cipher value around a context
	nc_Ctx           :: SecureMem -> c
-- | Block ciphers: besides the block size this provides the crypt functions
-- both as directly callable values (used for ECB) and as function pointers
-- (handed to the CBC\/CFB\/CTR\/GCM mode functions).
class NettleCipher c => NettleBlockCipher c where
	-- | block size; ECB\/CBC\/CFB inputs must be a multiple of it
	nbc_blockSize          :: Tagged c Int
	-- | applied to the context pointer before it is passed to the encrypt
	-- function; defaults to 'id'
	nbc_encrypt_ctx_offset :: Tagged c (Ptr Word8 -> Ptr Word8)
	nbc_encrypt_ctx_offset = Tagged id
	-- | applied to the context pointer before it is passed to the decrypt
	-- function; defaults to 'id'
	nbc_decrypt_ctx_offset :: Tagged c (Ptr Word8 -> Ptr Word8)
	nbc_decrypt_ctx_offset = Tagged id
	-- | directly callable encrypt function (used by @nettle_ecbEncrypt@)
	nbc_ecb_encrypt        :: Tagged c NettleCryptFunc
	-- | directly callable decrypt function (used by @nettle_ecbDecrypt@)
	nbc_ecb_decrypt        :: Tagged c NettleCryptFunc
	-- | encrypt function as a function pointer, for the block mode helpers
	nbc_fun_encrypt        :: Tagged c (FunPtr NettleCryptFunc)
	-- | decrypt function as a function pointer, for the block mode helpers
	nbc_fun_decrypt        :: Tagged c (FunPtr NettleCryptFunc)
-- | Stream ciphers whose context can be fed input of arbitrary length.
class NettleCipher c => NettleStreamCipher c where
	-- | crypt function; it is run on a copy of the context (see @stream_crypt@)
	nsc_streamCombine      :: Tagged c NettleCryptFunc
	-- | nonce size specification; defaults to an empty enumeration
	nsc_nonceSize          :: Tagged c T.KeySizeSpecifier
	nsc_nonceSize          = Tagged $ T.KeySizeEnum []
	-- | optional nonce setter, called with context pointer, nonce length and
	-- nonce pointer; defaults to 'Nothing', in which case
	-- @nettle_streamSetNonce@ returns 'Nothing'
	nsc_setNonce           :: Tagged c (Maybe (Ptr Word8 -> Word -> Ptr Word8 -> IO ()))
	nsc_setNonce           = Tagged Nothing

-- Stream ciphers that generate (large) blocks to XOR with the input but do
-- not keep the unused part of an incomplete block in their own state, so the
-- leftover bytes have to be tracked here (see nettle_blockedStreamCombine).
class NettleCipher c => NettleBlockedStreamCipher c where
	-- | size of the blocks the cipher generates
	nbsc_blockSize          :: Tagged c Int
	-- set new incomplete state
	nbsc_IncompleteState    :: c -> B.ByteString -> c
	-- get the current incomplete state: the leftover bytes that are XORed
	-- with the next input before the cipher is run again
	nbsc_incompleteState    :: c -> B.ByteString
	-- | crypt function; it is run on a copy of the context (see @stream_crypt@)
	nbsc_streamCombine      :: Tagged c NettleCryptFunc
	-- | nonce size specification; defaults to an empty enumeration
	nbsc_nonceSize          :: Tagged c T.KeySizeSpecifier
	nbsc_nonceSize          = Tagged $ T.KeySizeEnum []
	-- | optional nonce setter, called with context pointer, nonce length and
	-- nonce pointer; defaults to 'Nothing', in which case
	-- @nettle_blockedStreamSetNonce@ returns 'Nothing'
	nbsc_setNonce           :: Tagged c (Maybe (Ptr Word8 -> Word -> Ptr Word8 -> IO ()))
	nbsc_setNonce           = Tagged Nothing

-- | Initialize a cipher from a key using the cipher's own 'nc_cipherInit'.
--
-- The result is only used (lazily) as a type witness to select the init
-- function, so the self-reference below is well-defined.
nettle_cipherInit :: NettleCipher c => Key c -> c
nettle_cipherInit key = cipher
	where
		cipher = nettle_cipherInit' (witness nc_cipherInit cipher) key

-- | Initialize a cipher from a key with an explicitly given init function
-- (context pointer, key length, key pointer).
--
-- A context of 'nc_ctx_size' is allocated and filled by the init function;
-- the result is only used (lazily) as a type witness to look up that size.
nettle_cipherInit' :: NettleCipher c => (Ptr Word8 -> Word -> Ptr Word8 -> IO()) -> Key c -> c
nettle_cipherInit' initfun key = cipher
	where
		ctxsize = witness nc_ctx_size cipher
		cipher = nc_Ctx (key_init initfun ctxsize key)

-- | Return @result@ if the length of @src@ is a multiple of the cipher's
-- block size; otherwise fail with 'error'.
assert_blockSize :: NettleBlockCipher c => c -> B.ByteString -> a -> a
assert_blockSize c src result
	| B.length src `mod` blocksize == 0 = result
	| otherwise = error "input not a multiple of blockSize"
	where
		blocksize = witness nbc_blockSize c

-- | ECB encryption; fails with 'error' unless the input is a whole number of
-- blocks.
nettle_ecbEncrypt :: NettleBlockCipher c => c -> B.ByteString -> B.ByteString
nettle_ecbEncrypt c src = assert_blockSize c src encrypted
	where
		ctxoffset = witness nbc_encrypt_ctx_offset c
		encrypted = c_run_crypt ctxoffset (witness nbc_ecb_encrypt c) (nc_ctx c) src
-- | ECB decryption; fails with 'error' unless the input is a whole number of
-- blocks.
nettle_ecbDecrypt :: NettleBlockCipher c => c -> B.ByteString -> B.ByteString
nettle_ecbDecrypt c src = assert_blockSize c src decrypted
	where
		ctxoffset = witness nbc_decrypt_ctx_offset c
		decrypted = c_run_crypt ctxoffset (witness nbc_ecb_decrypt c) (nc_ctx c) src
-- | CBC encryption with the given IV; fails with 'error' unless the input is
-- a whole number of blocks.
nettle_cbcEncrypt :: NettleBlockCipher c => c -> IV c -> B.ByteString -> B.ByteString
nettle_cbcEncrypt c iv src = assert_blockSize c src encrypted
	where
		ctxoffset = witness nbc_encrypt_ctx_offset c
		blockfun = witness nbc_fun_encrypt c
		encrypted = blockmode_run ctxoffset c_cbc_encrypt blockfun (nc_ctx c) iv src
-- | CBC decryption with the given IV; fails with 'error' unless the input is
-- a whole number of blocks.
nettle_cbcDecrypt :: NettleBlockCipher c => c -> IV c -> B.ByteString -> B.ByteString
nettle_cbcDecrypt c iv src = assert_blockSize c src decrypted
	where
		ctxoffset = witness nbc_decrypt_ctx_offset c
		blockfun = witness nbc_fun_decrypt c
		decrypted = blockmode_run ctxoffset c_cbc_decrypt blockfun (nc_ctx c) iv src
-- | CFB encryption with the given IV; fails with 'error' unless the input is
-- a whole number of blocks.
nettle_cfbEncrypt :: NettleBlockCipher c => c -> IV c -> B.ByteString -> B.ByteString
nettle_cfbEncrypt c iv src = assert_blockSize c src encrypted
	where
		ctxoffset = witness nbc_encrypt_ctx_offset c
		blockfun = witness nbc_fun_encrypt c
		encrypted = blockmode_run ctxoffset c_cfb_encrypt blockfun (nc_ctx c) iv src
-- | CFB decryption with the given IV; fails with 'error' unless the input is
-- a whole number of blocks.
--
-- Note that this deliberately passes the block cipher's /encrypt/ function
-- and context offset: CFB runs the block cipher in the encrypt direction for
-- decryption as well.
nettle_cfbDecrypt :: NettleBlockCipher c => c -> IV c -> B.ByteString -> B.ByteString
nettle_cfbDecrypt c iv src = assert_blockSize c src decrypted
	where
		ctxoffset = witness nbc_encrypt_ctx_offset c
		blockfun = witness nbc_fun_encrypt c
		decrypted = blockmode_run ctxoffset c_cfb_decrypt blockfun (nc_ctx c) iv src
-- | CTR mode combine (the same operation encrypts and decrypts) with the
-- given IV. Unlike the ECB\/CBC\/CFB wrappers no block size check is done on
-- the input.
nettle_ctrCombine :: NettleBlockCipher c => c -> IV c -> B.ByteString -> B.ByteString
nettle_ctrCombine c iv src = blockmode_run ctxoffset c_ctr_crypt blockfun (nc_ctx c) iv src
	where
		ctxoffset = witness nbc_encrypt_ctx_offset c
		blockfun = witness nbc_fun_encrypt c

-- | Run the stream cipher over the input, returning the output together
-- with the cipher in its advanced state (the original cipher value is left
-- untouched, since 'stream_crypt' works on a copy of the context).
nettle_streamCombine :: NettleStreamCipher c => c -> B.ByteString -> (B.ByteString, c)
nettle_streamCombine c indata = (outdata, nc_Ctx ctx')
	where
		(outdata, ctx') = stream_crypt (witness nsc_streamCombine c) (nc_ctx c) indata
-- | Set a nonce on a stream cipher. Returns 'Nothing' if the cipher has no
-- nonce setter ('nsc_setNonce'); otherwise the setter is run on a copy of
-- the context and a cipher built from that copy is returned.
nettle_streamSetNonce :: NettleStreamCipher c => c -> B.ByteString -> Maybe c
nettle_streamSetNonce c nonce = case witness nsc_setNonce c of
	Nothing -> Nothing
	Just setnonce -> unsafeDupablePerformIO $ do
		-- never touch the caller's context: work on a private copy
		ctx' <- secureMemCopy (nc_ctx c)
		withSecureMemPtr ctx' $ \ctxptr ->
			withByteStringPtr nonce $ \noncelen nonceptr ->
				setnonce ctxptr noncelen nonceptr
		return (Just (nc_Ctx ctx'))

-- | Stream combine for ciphers that only work on whole blocks.
--
-- The leftover bytes of the last generated block are kept in the cipher's
-- incomplete state (see 'nbsc_incompleteState') and are consumed first on
-- the next call. Returns the output and the advanced cipher.
nettle_blockedStreamCombine :: NettleBlockedStreamCipher c => c -> B.ByteString -> (B.ByteString, c)
nettle_blockedStreamCombine c indata
	| B.null indata = (indata, c)
	| not (B.null pending) = drainPending
	| overhang /= 0 = cryptPadded
	| otherwise = cryptAligned
	where
		pending = nbsc_incompleteState c
		blocksiz = witness nbsc_blockSize c
		inlen = B.length indata
		-- number of input bytes beyond the last full block
		overhang = inlen `mod` blocksiz
		crypt = stream_crypt (witness nbsc_streamCombine c) (nc_ctx c)

		-- leftover bytes from an earlier call: xor them with the start of
		-- the input first, then combine whatever input remains
		drainPending = (B.append headOut tailOut, cFinal)
			where
				(inHead, inTail) = B.splitAt (B.length pending) indata
				(usedPending, restPending) = B.splitAt inlen pending
				headOut = B.pack (B.zipWith xor inHead usedPending)
				cRest
					| B.null restPending = nc_Ctx (nc_ctx c)
					| otherwise = nbsc_IncompleteState c restPending
				(tailOut, cFinal) = nettle_blockedStreamCombine cRest inTail

		-- input does not end on a block boundary: pad with zeros up to a
		-- full block; the output for the padding is xor-ed with zero and
		-- is stored as the new incomplete state
		cryptPadded = (outdata, nbsc_IncompleteState (nc_Ctx ctx') leftover)
			where
				padding = B.replicate (blocksiz - overhang) 0
				(padded, ctx') = crypt (B.append indata padding)
				(outdata, leftover) = B.splitAt inlen padded

		-- input is a whole number of blocks: nothing is left over
		cryptAligned = (outdata, nc_Ctx ctx')
			where
				(outdata, ctx') = crypt indata
-- | Set a nonce on a blocked stream cipher. Returns 'Nothing' if the cipher
-- has no nonce setter ('nbsc_setNonce'); otherwise the setter is run on a
-- copy of the context and a cipher built from that copy via 'nc_Ctx' is
-- returned.
nettle_blockedStreamSetNonce :: NettleBlockedStreamCipher c => c -> B.ByteString -> Maybe c
nettle_blockedStreamSetNonce c nonce = case witness nbsc_setNonce c of
	Nothing -> Nothing
	Just setnonce -> unsafeDupablePerformIO $ do
		-- never touch the caller's context: work on a private copy
		ctx' <- secureMemCopy (nc_ctx c)
		withSecureMemPtr ctx' $ \ctxptr ->
			withByteStringPtr nonce $ \noncelen nonceptr ->
				setnonce ctxptr noncelen nonceptr
		return (Just (nc_Ctx ctx'))


-- | Set up GCM for a block cipher with the given IV. Only ciphers with a
-- block size of 16 are accepted; for any other block size the result is
-- 'Nothing'.
nettle_gcm_aeadInit :: (NettleBlockCipher c, AEADModeImpl c NettleGCM, Byteable iv) => c -> iv -> Maybe (AEAD c)
nettle_gcm_aeadInit c iv
	| witness nbc_blockSize c == 16 = Just (AEAD c (AEADState gcmstate))
	| otherwise = Nothing
	where
		gcmstate = gcm_init (witness nbc_encrypt_ctx_offset c) (witness nbc_fun_encrypt c) (nc_ctx c) iv
-- | Feed associated data (header) into the GCM state. The cipher argument is
-- ignored; the update does not depend on the cipher.
nettle_gcm_aeadStateAppendHeader :: t -> NettleGCM -> B.ByteString -> NettleGCM
nettle_gcm_aeadStateAppendHeader _ state header = gcm_update state header
-- | GCM-encrypt a chunk of data, returning the ciphertext and the advanced
-- GCM state.
nettle_gcm_aeadStateEncrypt :: NettleBlockCipher c => c -> NettleGCM -> B.ByteString -> (B.ByteString, NettleGCM)
nettle_gcm_aeadStateEncrypt c state indata = gcm_crypt c_gcm_encrypt ctxoffset blockfun (nc_ctx c) state indata
	where
		ctxoffset = witness nbc_encrypt_ctx_offset c
		blockfun = witness nbc_fun_encrypt c
-- | GCM-decrypt a chunk of data, returning the plaintext and the advanced
-- GCM state. As with encryption, the block cipher's /encrypt/ function is
-- passed to the GCM mode function.
nettle_gcm_aeadStateDecrypt :: NettleBlockCipher c => c -> NettleGCM -> B.ByteString -> (B.ByteString, NettleGCM)
nettle_gcm_aeadStateDecrypt c state indata = gcm_crypt c_gcm_decrypt ctxoffset blockfun (nc_ctx c) state indata
	where
		ctxoffset = witness nbc_encrypt_ctx_offset c
		blockfun = witness nbc_fun_encrypt c
-- | Compute the authentication tag of the requested length for the current
-- GCM state.
nettle_gcm_aeadStateFinalize :: NettleBlockCipher c => c -> NettleGCM -> Int -> AuthTag
nettle_gcm_aeadStateFinalize c state taglen = gcm_digest ctxoffset blockfun (nc_ctx c) state taglen
	where
		ctxoffset = witness nbc_encrypt_ctx_offset c
		blockfun = witness nbc_fun_encrypt c



-- Allocate a context of the given size and let the init function fill it
-- from the key. The init function receives the context pointer, the key
-- length and the key pointer.
key_init
	:: ToSecureMem k
	=> (Ptr Word8 -> Word -> Ptr Word8 -> IO ())
	-> Int -> k -> SecureMem
key_init initfun size k = unsafeCreateSecureMem size fillctx
	where
		fillctx ctxptr =
			withSecureMemPtrSz (toSecureMem k) $ \keylen keyptr ->
				initfun ctxptr (fromIntegral keylen) keyptr

-- Run a crypt function whose output has the same length as its input.
-- The first argument is applied to the context pointer before the crypt
-- function sees it; the context itself is used in place (not copied).
c_run_crypt
	:: (Ptr Word8 -> Ptr Word8)
	-> NettleCryptFunc
	-> SecureMem -> B.ByteString -> B.ByteString
c_run_crypt ctxoffset cfun ctx indata = unsafeDupablePerformIO (withSecureMemPtr ctx run)
	where
		run ctxptr =
			withByteStringPtr indata $ \inlen inptr ->
				B.create (B.length indata) $ \outptr ->
					cfun (ctxoffset ctxptr) inlen outptr inptr

-- Run a block mode function (CBC, CFB, CTR, ...) over the input. The mode
-- function receives the (offset) cipher context, the block crypt function
-- pointer, IV length and pointer, and the data length, output and input
-- pointers. The output has the same length as the input.
blockmode_run
	:: (Byteable iv)
	=> (Ptr Word8 -> Ptr Word8)
	-> NettleBlockMode
	-> FunPtr NettleCryptFunc
	-> SecureMem -> iv -> B.ByteString -> B.ByteString
blockmode_run ctxoffset mode crypt ctx iv indata = unsafeDupablePerformIO $
	withSecureMemPtr ctx $ \ctxptr ->
		withByteStringPtr indata $ \inlen inptr ->
			-- the IV may get modified by the mode function, so it is
			-- handed a copy rather than the caller's bytes
			withSecureMemPtrSz (toSecureMem (toBytes iv)) $ \ivlen ivptr ->
				B.create (B.length indata) $ \outptr ->
					mode (ctxoffset ctxptr) crypt (fromIntegral ivlen) ivptr inlen outptr inptr

-- | GCM state. The first field is the GCM context (allocated with
-- @c_gcm_ctx_size@ and initialized by @c_gcm_set_iv@); the second is the GCM
-- key (allocated with @c_gcm_key_size@ and initialized by @c_gcm_set_key@).
-- The update functions copy the context and share the key unchanged.
data NettleGCM = NettleGCM !SecureMem !SecureMem

-- Create a fresh GCM state for an initialized block cipher context and an
-- IV: first the GCM key is derived from the cipher via c_gcm_set_key, then
-- a GCM context is set up from that key and the IV via c_gcm_set_iv.
gcm_init
	:: (Byteable iv)
	=> (Ptr Word8 -> Ptr Word8)
	-> FunPtr NettleCryptFunc
	-> SecureMem -> iv -> NettleGCM
gcm_init encctxoffset encrypt encctx iv = unsafeDupablePerformIO $
	withBytePtr iv $ \ivptr ->
		withSecureMemPtr encctx $ \encctxptr -> do
			key <- createSecureMem c_gcm_key_size $ \keyptr ->
				c_gcm_set_key keyptr (encctxoffset encctxptr) encrypt
			state <- withSecureMemPtr key $ \keyptr ->
				createSecureMem c_gcm_ctx_size $ \stateptr ->
					c_gcm_set_iv stateptr keyptr (fromIntegral (byteableLength iv)) ivptr
			return (NettleGCM state key)

-- Feed data into the GCM state via c_gcm_update. This step is independent
-- of the block cipher. The GCM context is copied first, so the state passed
-- in stays valid; the key is shared with the result.
gcm_update
	:: NettleGCM -> B.ByteString -> NettleGCM
gcm_update (NettleGCM ctx h) indata = unsafeDupablePerformIO $ do
	ctx' <- secureMemCopy ctx
	withSecureMemPtr ctx' $ \ctxptr ->
		withSecureMemPtr h $ \hptr ->
			withByteStringPtr indata $ \inlen inptr ->
				c_gcm_update ctxptr hptr inlen inptr
	return (NettleGCM ctx' h)

-- Encrypt or decrypt (depending on the given mode function) a chunk of data
-- in GCM mode. The GCM context is copied first, so the state passed in
-- stays valid; returns the output (same length as the input) and the
-- advanced state.
gcm_crypt
	:: NettleGCMMode
	-> (Ptr Word8 -> Ptr Word8)
	-> FunPtr NettleCryptFunc
	-> SecureMem -> NettleGCM -> B.ByteString -> (B.ByteString, NettleGCM)
gcm_crypt mode encctxoffset encrypt encctx (NettleGCM ctx h) indata = unsafeDupablePerformIO $ do
	ctx' <- secureMemCopy ctx
	outdata <- withSecureMemPtr ctx' $ \ctxptr ->
		withSecureMemPtr h $ \hptr ->
			withSecureMemPtr encctx $ \encctxptr ->
				withByteStringPtr indata $ \inlen inptr ->
					B.create (B.length indata) $ \outptr ->
						mode ctxptr hptr (encctxoffset encctxptr) encrypt inlen outptr inptr
	return (outdata, NettleGCM ctx' h)

-- Compute an authentication tag of taglen bytes for the given GCM state.
-- The digest is computed on a copy of the GCM context, so the state passed
-- in is not touched.
--
-- NOTE(review): taglen is passed to c_gcm_digest unchecked; confirm how the
-- C side behaves for lengths larger than the GCM block size.
gcm_digest
	:: (Ptr Word8 -> Ptr Word8)
	-> FunPtr NettleCryptFunc
	-> SecureMem -> NettleGCM -> Int -> AuthTag
gcm_digest encctxoffset encrypt encctx (NettleGCM ctx h) taglen = unsafeDupablePerformIO $ do
	ctx' <- secureMemCopy ctx
	tag <- withSecureMemPtr ctx' $ \ctxptr ->
		withSecureMemPtr h $ \hptr ->
			withSecureMemPtr encctx $ \encctxptr ->
				B.create taglen $ \tagptr ->
					c_gcm_digest ctxptr hptr (encctxoffset encctxptr) encrypt (fromIntegral taglen) tagptr
	return (AuthTag tag)

-- Run a stream crypt function over the input. The function is run on a copy
-- of the context, so the context passed in is left untouched; returns the
-- output (same length as the input) together with the advanced copy.
stream_crypt
	:: NettleCryptFunc
	-> SecureMem -> B.ByteString -> (B.ByteString, SecureMem)
stream_crypt crypt ctx indata = unsafeDupablePerformIO $ do
	ctx' <- secureMemCopy ctx
	outdata <- withSecureMemPtr ctx' $ \ctxptr ->
		withByteStringPtr indata $ \inlen inptr ->
			B.create (B.length indata) $ \outptr ->
				crypt ctxptr inlen outptr inptr
	return (outdata, ctx')