/* ekko.c - symmetric encryption utility
 *
 * Use like:
 *   ekko -g             generate a new key
 *   ekko -e p c         encrypt p to c
 *   ekko -d c p         decrypt c to p
 *
 * Keys for -e and -d are supplied as the first line of stdin, encoded as hex.
 * The key generated by -g is emitted on stdout. Decryption will fail (without
 * creating the output file at all) if the ciphertext has been modified.
 *
 * Needs bearssl to work. We use AES-256 in EAX mode with a random 256-bit
 * nonce. Encrypted files have a header:
 *   'E' 'K' 'K' 'O' 0 0 0 0 (8 bytes)
 *   nonce (32 bytes)
 *   EAX tag (16 bytes)
 * so the actual encrypted data starts at offset 56. Since EAX is an AEAD mode,
 * the authenticator tag proves that the ciphertext hasn't been modified.
 *
 * Keys have no structure; they are just 256 random bits encoded as hex.
 */

#include <bearssl_aead.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/random.h>
#include <unistd.h>

/* Sizes, in bytes, of the AES key, the EAX nonce and the EAX tag. */
#define KEYLEN   32
#define NONCELEN 32
#define TAGLEN   16

/* Fixed 8-byte marker written at the start of every encrypted file by enc()
 * and checked by dec() before anything else is trusted. */
const uint8_t FILE_MAGIC[] = { 'E', 'K', 'K', 'O', 0, 0, 0, 0 };

/* Report a fatal error on stderr and terminate at once with status 1.
 *
 * _exit() is used rather than exit(), so stdio buffers are not flushed and
 * atexit handlers do not run. This function never returns. */
void panic(const char *msg) {
	(void)fprintf(stderr, "panic: %s\n", msg);

	_exit(1);
}

/* Same as a plain fatal error report, but with one extra string of context
 * (typically a file name) appended in parentheses after the message.
 *
 * Terminates with status 1 via _exit(); never returns. */
void panics(const char *msg, const char *arg) {
	(void)fprintf(stderr, "panic: %s (%s)\n", msg, arg);

	_exit(1);
}

/* Encode `bytes` bytes from `in` as lowercase hexadecimal text.
 *
 * `out` must have room for 2 * bytes characters plus the terminating NUL,
 * which is always written (even when bytes is 0). */
void hex(const uint8_t *in, char *out, size_t bytes) {
	static const char digits[] = "0123456789abcdef";

	while (bytes-- > 0) {
		uint8_t b = *in++;

		/* high nibble first, then low nibble */
		*out++ = digits[b >> 4];
		*out++ = digits[b & 0x0f];
	}
	*out = '\0';
}

/* Convert one hexadecimal digit to its value (0..15).
 *
 * Both lowercase and uppercase digits are accepted, so keys produced or
 * re-typed by other tools in uppercase still decode. Any other character is
 * a fatal error: panic() is called and does not return. */
int unhexc(char c) {
	if (c >= '0' && c <= '9')
		return c - '0';
	else if (c >= 'a' && c <= 'f')
		return c - 'a' + 10;
	else if (c >= 'A' && c <= 'F')
		return c - 'A' + 10;
	else
		panic("bogus hex character?");
	/* not reached; keeps compilers that can't see panic() exits quiet */
	return 0;
}

/* Decode 2 * bytes hexadecimal characters from `in` into `bytes` raw bytes
 * at `out`. Each input pair is (high nibble, low nibble). Invalid digits are
 * handled by unhexc(), which panics. */
void unhex(const char *in, uint8_t *out, size_t bytes) {
	while (bytes-- > 0) {
		int hi = unhexc(in[0]);
		int lo = unhexc(in[1]);

		*out++ = (hi << 4) | lo;
		in += 2;
	}
}

/* Print the command-line synopsis to stdout, using `progn` as the program
 * name. Always returns 1 so the caller can use it directly as an exit
 * status. */
int usage(const char *progn) {
	printf("Usage: %s -g             generate key\n"
	       "       %s -e <p> <c>     encrypt p to c\n"
	       "       %s -d <c> <p>     decrypt c to p\n",
	       progn, progn, progn);
	return 1;
}

/* Generate a fresh random key and print it on stdout as one line of hex.
 *
 * The key bytes come from getrandom(); a failed or short read is fatal.
 * Returns 0 on success. */
int gen() {
	uint8_t raw[KEYLEN];
	char text[sizeof(raw) * 2 + 1];
	ssize_t got;

	got = getrandom(raw, sizeof(raw), 0);
	if (got != sizeof(raw))
		panic("getrandom() failed");

	hex(raw, text, sizeof(raw));
	printf("%s\n", text);

	return 0;
}

/* Read the key from the first line of stdin and decode it into `key`
 * (KEYLEN bytes).
 *
 * The line must hold exactly KEYLEN * 2 hex digits, optionally followed by
 * a line terminator ("\n" or "\r\n"). Missing input, a short line, or an
 * overlong line is fatal; non-hex characters are rejected by unhex(). */
void readkey(uint8_t *key) {
	/* One byte more than the digits + NUL, so that an overlong line shows
	 * up as an extra character instead of being silently cut to length. */
	char hexkey[KEYLEN * 2 + 2];
	size_t len;

	/* On EOF/error fgets() leaves the buffer unusable; don't look at it. */
	if (!fgets(hexkey, sizeof(hexkey), stdin))
		panic("no key on stdin");

	len = strcspn(hexkey, "\r\n");
	hexkey[len] = 0;

	if (len < KEYLEN * 2)
		panic("truncated key on stdin");
	if (len > KEYLEN * 2)
		panic("overlong key on stdin");
	unhex(hexkey, key, KEYLEN);
}

/* Everything needed for one encryption or decryption run. Filled in by
 * cryptosetup(); `ctx` is initialised with a pointer into `bc`, so the two
 * must stay together for as long as the context is in use. */
struct cryptostate {
	br_aes_gen_ctrcbc_keys bc;  /* AES block cipher state, keyed in cryptosetup() */
	br_eax_context ctx;         /* EAX context running on top of `bc` */
};

/* Prepare `cs` for processing one message.
 *
 * `key` must point at KEYLEN bytes and `nonce` at NONCELEN bytes. On return
 * the EAX context is keyed, reset with the nonce, and already flipped, so
 * the caller can go straight to br_eax_run(). */
void cryptosetup(struct cryptostate *cs, const uint8_t *key,
                 const uint8_t *nonce) {
	memset(cs, 0, sizeof(*cs));

	/* Key the AES implementation, then layer EAX on top of it. */
	br_aes_big_ctrcbc_vtable.init(&cs->bc.vtable, key, KEYLEN);
	br_eax_init(&cs->ctx, &cs->bc.vtable);
	br_eax_reset(&cs->ctx, nonce, NONCELEN);

	/* No additional authenticated data is ever injected, so the context
	 * can be flipped right away. */
	br_eax_flip(&cs->ctx);
}

/* Like fwrite() and fread(), but they require a successful write/read of
 * exactly the specified number of bytes, or they panic. */

/* Write exactly `len` bytes from `buf` to `f`, or panic.
 *
 * A zero-length write succeeds trivially. It has to be special-cased:
 * fwrite() returns 0 whenever the element size is 0, which would otherwise
 * be mistaken for a short write. */
void fwritex(const uint8_t *buf, size_t len, FILE *f) {
	if (len == 0)
		return;
	if (fwrite(buf, len, 1, f) != 1)
		panic("short fwrite()");
}

/* Read exactly `len` bytes from `f` into `buf`, or panic.
 *
 * A zero-length read succeeds trivially. It has to be special-cased:
 * fread() returns 0 whenever the element size is 0, which would otherwise
 * be mistaken for a short read. */
void freadx(uint8_t *buf, size_t len, FILE *f) {
	if (len == 0)
		return;
	if (fread(buf, len, 1, f) != 1)
		panic("short fread()");
}

/* Create and open a temporary file next to `ofn`, named "<ofn>.XXXXXX" with
 * the X's replaced by mkstemp().
 *
 * `path` must point at a buffer of at least _POSIX_PATH_MAX bytes; it
 * receives the NUL-terminated name of the file that was created. Returns a
 * stream opened for binary writing. Every failure is fatal (panic), so the
 * function never returns NULL. */
FILE *opentemp(const char *ofn, char *path) {
	char template[_POSIX_PATH_MAX];
	int fd;
	int n;
	FILE *f;

	/* A truncated template would lose (part of) the XXXXXX suffix that
	 * mkstemp() needs, so refuse names that don't fit. */
	n = snprintf(template, sizeof(template), "%s.XXXXXX", ofn);
	if (n < 0 || (size_t)n >= sizeof(template))
		panic("temporary file name too long");

	if ((fd = mkstemp(template)) < 0)
		panic("can't make temporary file");
	f = fdopen(fd, "wb");
	if (!f) {
		/* don't leave the empty temp file behind */
		unlink(template);
		panic("can't fdopen");
	}
	memcpy(path, template, (size_t)n + 1);
	return f;
}

/* Run the whole of `in` through the EAX context in `cs` and write the
 * result to `out`.
 *
 * `encrypt` is passed straight to br_eax_run() (non-zero to encrypt, zero
 * to decrypt). Data is processed in place in 4096-byte chunks. When the
 * input is exhausted the authentication tag is stored in `tag` (callers
 * pass a TAGLEN-byte buffer). I/O errors on either stream are fatal. */
void stream(struct cryptostate *cs, int encrypt, FILE *in, FILE *out,
            uint8_t *tag) {
	uint8_t buf[4096];
	size_t r;

	for (;;) {
		r = fread(buf, 1, sizeof(buf), in);
		/* Only touch the cipher and the output when there is data: an
		 * empty input, or one whose size is a multiple of the chunk
		 * size, ends with a zero-byte read. */
		if (r > 0) {
			br_eax_run(&cs->ctx, encrypt, buf, r);
			fwritex(buf, r, out);
		}
		/* A short count means end-of-file or a read error. */
		if (r < sizeof(buf))
			break;
	}

	if (ferror(in))
		panic("broken input file");
	if (ferror(out))
		panic("broken output file");

	br_eax_get_tag(&cs->ctx, tag);
}

/* Encrypt the file `pfn` into `cfn`, using the key read from stdin.
 *
 * The ciphertext is first written to a temporary file beside `cfn`: header
 * magic, a fresh random nonce, a placeholder for the tag, then the encrypted
 * data. Once the data has been streamed the real tag is written over the
 * placeholder. The temporary file is renamed onto `cfn` only after it has
 * been flushed and closed successfully, so `cfn` never ends up holding a
 * partially written file. Returns 0 on success; all failures are fatal. */
int enc(const char *pfn, const char *cfn) {
	uint8_t key[KEYLEN];
	uint8_t nonce[NONCELEN];
	uint8_t tag[TAGLEN];

	struct cryptostate crypto;

	FILE *pf, *tf;
	char temppath[_POSIX_PATH_MAX];

	readkey(key);
	if (getrandom(nonce, sizeof(nonce), 0) != sizeof(nonce))
		panic("getrandom() failed");
	cryptosetup(&crypto, key, nonce);

	pf = fopen(pfn, "rb");
	if (!pf)
		panics("can't open plaintext file", pfn);
	tf = opentemp(cfn, temppath);
	if (!tf)
		panic("can't open temp file");

	memset(tag, 0, sizeof(tag));

	fwritex(FILE_MAGIC, sizeof(FILE_MAGIC), tf);
	fwritex(nonce, sizeof(nonce), tf);
	/* write the all-0 auth tag out, we'll fill it in later */
	fwritex(tag, sizeof(tag), tf);

	stream(&crypto, 1, pf, tf, tag);

	/* Go back and replace the placeholder with the real tag. */
	if (fseek(tf, (long)(sizeof(FILE_MAGIC) + sizeof(nonce)), SEEK_SET)) {
		unlink(temppath);
		panic("can't seek back to the tag");
	}
	fwritex(tag, sizeof(tag), tf);

	/* Close (and thereby flush) before renaming: buffered data that fails
	 * to reach the disk must stop the rename, not be discovered after the
	 * truncated file is already in place. */
	if (fclose(tf)) {
		unlink(temppath);
		panic("can't finish writing ciphertext file");
	}
	fclose(pf);

	if (rename(temppath, cfn)) {
		unlink(temppath);
		panics("can't rename ciphertext file", cfn);
	}

	return 0;
}

/* Decrypt the file `cfn` into `pfn`, using the key read from stdin.
 *
 * The header (magic, nonce, stored tag) is read and the magic checked, then
 * the rest of the file is decrypted into a temporary file beside `pfn`. The
 * decrypted data is unauthenticated until the computed tag has been compared
 * with the stored one, so on a mismatch the temporary file is removed before
 * failing and `pfn` is never created. Only after a successful check, flush
 * and close is the temporary file renamed onto `pfn`. Returns 0 on success;
 * all failures are fatal. */
int dec(const char *cfn, const char *pfn) {
	uint8_t key[KEYLEN];
	uint8_t nonce[NONCELEN];
	uint8_t ftag[TAGLEN];
	uint8_t tag[TAGLEN];
	uint8_t magic[sizeof(FILE_MAGIC)];
	unsigned diff = 0;
	size_t i;

	struct cryptostate crypto;

	FILE *cf, *tf;
	char temppath[_POSIX_PATH_MAX];

	readkey(key);

	cf = fopen(cfn, "rb");
	if (!cf)
		panics("can't open ciphertext file", cfn);

	freadx(magic, sizeof(magic), cf);
	freadx(nonce, sizeof(nonce), cf);
	freadx(ftag, sizeof(ftag), cf);

	if (memcmp(magic, FILE_MAGIC, sizeof(magic)))
		panics("not an ekko-encrypted file", cfn);

	cryptosetup(&crypto, key, nonce);

	tf = opentemp(pfn, temppath);
	if (!tf)
		panic("can't open temp file");

	stream(&crypto, 0, cf, tf, tag);

	/* Compare the whole tag without an early exit, so the time taken does
	 * not depend on how many leading bytes happen to match. */
	for (i = 0; i < sizeof(tag); i++)
		diff |= (unsigned)(ftag[i] ^ tag[i]);
	if (diff) {
		/* The temp file holds unauthenticated plaintext; get rid of it. */
		unlink(temppath);
		panic("authentication failed");
	}

	/* Close (and thereby flush) before renaming, so a failed flush can't
	 * leave a truncated plaintext file in place. */
	if (fclose(tf)) {
		unlink(temppath);
		panic("can't finish writing plaintext file");
	}
	fclose(cf);

	if (rename(temppath, pfn)) {
		unlink(temppath);
		panics("can't rename plaintext file", pfn);
	}

	return 0;
}

/* Entry point: dispatch on the mode flag in argv[1].
 *
 *   -g          (no further arguments)   generate a key
 *   -e <p> <c>                           encrypt p to c
 *   -d <c> <p>                           decrypt c to p
 *
 * Anything else prints the usage text. The return value of the selected
 * operation (or of usage()) becomes the exit status. */
int main(int argc, char *argv[]) {
	const char *mode = (argc >= 2) ? argv[1] : "";

	if (argc == 2 && strcmp(mode, "-g") == 0)
		return gen();

	if (argc == 4) {
		if (strcmp(mode, "-e") == 0)
			return enc(argv[2], argv[3]);
		if (strcmp(mode, "-d") == 0)
			return dec(argv[2], argv[3]);
	}

	return usage(argv[0]);
}
