// SPDX-FileCopyrightText: 2015-2018 nodepad <nod3pad@gmail.com>
// SPDX-FileCopyrightText: 2015-2018 pancake <pancake@nopcode.org>
// SPDX-License-Identifier: LGPL-3.0-only

#include "mz.h"
#include <rz_list.h>

/* Combine a segment:offset pair into a linear address
 * (segment * 16 + offset). */
static ut64 rz_bin_mz_va_to_la(const ut16 segment, const ut16 offset) {
	const ut64 linear = (segment << 4) + offset;
	return linear;
}

/* Turn a linear address into a file offset by adding the header size,
 * which the DOS header stores as a count of 16-byte paragraphs. */
static ut64 rz_bin_mz_la_to_pa(const struct rz_bin_mz_obj_t *bin, ut64 la) {
	const MZ_image_dos_header *hdr = bin->dos_header;
	return la + (hdr->header_paragraphs << 4);
}

/**
 * Compute the program entry point from the initial cs:ip header values.
 *
 * The linear address is masked to 20 bits and must lie inside the load
 * module. Returns a newly allocated RzBinAddr with vaddr/paddr filled in,
 * or NULL when \p bin is unusable, the entry lies outside the load module
 * or the allocation fails.
 */
RzBinAddr *rz_bin_mz_get_entrypoint(const struct rz_bin_mz_obj_t *bin) {
	if (!bin || !bin->dos_header) {
		return NULL;
	}

	const MZ_image_dos_header *hdr = bin->dos_header;
	ut64 entry_la = rz_bin_mz_va_to_la(hdr->cs, hdr->ip);
	entry_la &= 0xfffff;
	if (entry_la >= bin->load_module_size) {
		RZ_LOG_ERROR("The entry point is outside the load module size\n");
		return NULL;
	}

	RzBinAddr *addr = RZ_NEW0(RzBinAddr);
	if (!addr) {
		return NULL;
	}
	addr->vaddr = entry_la;
	addr->paddr = rz_bin_mz_la_to_pa(bin, entry_la);
	return addr;
}

/* Order two RzBinSection objects by their vaddr.
 *
 * Returns a negative value, zero or a positive value when a's address is
 * below, equal to or above b's. The result is computed with comparisons
 * rather than a subtraction, so a wide address difference can never be
 * truncated into a wrong sign (or a bogus 0) when converted to int. */
static int cmp_sections(const void *a, const void *b) {
	const RzBinSection *s_a = a;
	const RzBinSection *s_b = b;

	return (s_a->vaddr > s_b->vaddr) - (s_a->vaddr < s_b->vaddr);
}

/* Allocate a zero-initialised section whose vaddr is \p laddr.
 * \p bin is currently unused. Returns NULL if the allocation fails. */
static RzBinSection *rz_bin_mz_init_section(const struct rz_bin_mz_obj_t *bin,
	ut64 laddr) {
	RzBinSection *sec = RZ_NEW0(RzBinSection);
	if (!sec) {
		return NULL;
	}
	sec->vaddr = laddr;
	return sec;
}

/**
 * Build the list of segments of the load module.
 *
 * Segment start addresses come from three places: linear address 0 (always
 * present), the segment value stored at every relocation target, and the
 * initial stack segment from the header when it lies inside the load module.
 * The vector is kept sorted by address. Afterwards every segment is named
 * "seg_NNN", given rwx permissions and sized so that it extends up to the
 * start of the next segment; the last one extends to the end of the load
 * module.
 *
 * Returns a new vector that frees its elements with rz_bin_section_free, or
 * NULL when \p bin is unusable or an allocation fails.
 */
RzPVector /*<RzBinSection *>*/ *rz_bin_mz_get_segments(const struct rz_bin_mz_obj_t *bin) {
	RzPVector *seg_vec;
	void **iter;
	RzBinSection *section;
	RzBinSection *prev_section;
	MZ_image_relocation_entry *relocs;
	int i, num_relocs, section_number;
	ut64 ss_laddr;

	if (!bin || !bin->dos_header) {
		return NULL;
	}

	seg_vec = rz_pvector_new((RzPVectorFree)rz_bin_section_free);
	if (!seg_vec) {
		return NULL;
	}

	/* Add the address of the first segment to make sure that it is present
	 * even if there are no relocations or none of them refers to it. */
	section = rz_bin_mz_init_section(bin, 0);
	if (!section) {
		goto err_out;
	}
	rz_pvector_push(seg_vec, section);
	rz_pvector_sort(seg_vec, (RzPVectorComparator)cmp_sections, NULL);

	relocs = bin->relocation_entries;
	num_relocs = bin->dos_header->num_relocs;
	for (i = 0; i < num_relocs; i++) {
		/* Lookup key only: the comparator reads nothing but vaddr. */
		RzBinSection c = { 0 };
		ut64 laddr, paddr, section_laddr;
		ut16 curr_seg;

		/* The relocation target is a 16-bit word that must lie inside
		 * both the load module and the underlying buffer. */
		laddr = rz_bin_mz_va_to_la(relocs[i].segment, relocs[i].offset);
		if ((laddr + 2) >= bin->load_module_size) {
			continue;
		}

		paddr = rz_bin_mz_la_to_pa(bin, laddr);
		if (rz_buf_size(bin->b) < paddr + 2) {
			continue;
		}

		if (!rz_buf_read_le16_at(bin->b, paddr, &curr_seg)) {
			continue;
		}

		/* The word at the target is taken as a segment value. */
		section_laddr = rz_bin_mz_va_to_la(curr_seg, 0);
		if (section_laddr > bin->load_module_size) {
			continue;
		}

		/* Skip segments that are already known. */
		c.vaddr = section_laddr;
		if (rz_pvector_find(seg_vec, &c, (RzPVectorComparator)cmp_sections, NULL)) {
			continue;
		}

		section = rz_bin_mz_init_section(bin, section_laddr);
		if (!section) {
			goto err_out;
		}
		rz_pvector_push(seg_vec, section);
		rz_pvector_sort(seg_vec, (RzPVectorComparator)cmp_sections, NULL);
	}

	/* Add the address of the stack segment if it's inside the load module.
	 * NOTE(review): no duplicate check is done here, so a stack segment
	 * that matches an existing one yields a second entry — confirm that
	 * this is intended. */
	ss_laddr = rz_bin_mz_va_to_la(bin->dos_header->ss, 0);
	if (ss_laddr < bin->load_module_size) {
		section = rz_bin_mz_init_section(bin, ss_laddr);
		if (!section) {
			goto err_out;
		}
		rz_pvector_push(seg_vec, section);
		rz_pvector_sort(seg_vec, (RzPVectorComparator)cmp_sections, NULL);
	}

	/* Fix up sizes and addresses, set the name and the permissions. */
	section_number = 0;
	prev_section = NULL;
	rz_pvector_foreach (seg_vec, iter) {
		section = *iter;
		section->name = rz_str_newf("seg_%03d", section_number);
		if (prev_section) {
			/* The vector is sorted, so the previous segment ends where
			 * this one starts. The previous element is remembered
			 * directly instead of being re-derived from the iterator:
			 * a pointer difference is already an element count and
			 * must not be divided by the element size. */
			prev_section->size = section->vaddr - prev_section->vaddr;
			prev_section->vsize = prev_section->size;
		}
		section->vsize = section->size;
		section->paddr = rz_bin_mz_la_to_pa(bin, section->vaddr);
		section->perm = rz_str_rwx("rwx");
		prev_section = section;
		section_number++;
	}
	/* The last segment extends to the end of the load module. The vector is
	 * never empty here because the first segment is always added. */
	section = *rz_pvector_index_ptr(seg_vec, rz_pvector_len(seg_vec) - 1);
	section->size = bin->load_module_size - section->vaddr;
	section->vsize = section->size;

	return seg_vec;

err_out:
	RZ_LOG_ERROR("Failed to get segment list\n");
	rz_pvector_free(seg_vec);

	return NULL;
}

/**
 * Collect the relocations whose target lies inside the load module.
 *
 * Returns a calloc()-allocated array holding one entry per accepted
 * relocation (vaddr = linear address, paddr = file offset), followed by a
 * terminator entry with `last` set to 1. Returns NULL when \p bin or its DOS
 * header is missing, or when the allocation fails.
 */
struct rz_bin_mz_reloc_t *rz_bin_mz_get_relocs(const struct rz_bin_mz_obj_t *bin) {
	/* Same guard as the other getters: the header is dereferenced below. */
	if (!bin || !bin->dos_header) {
		return NULL;
	}

	int i, j;
	const int num_relocs = bin->dos_header->num_relocs;
	const MZ_image_relocation_entry *const rel_entry = bin->relocation_entries;

	/* One extra slot for the terminator entry. */
	struct rz_bin_mz_reloc_t *relocs = calloc(num_relocs + 1, sizeof(*relocs));
	if (!relocs) {
		RZ_LOG_ERROR("Cannot allocate struct rz_bin_mz_reloc_t\n");
		return NULL;
	}
	for (i = 0, j = 0; i < num_relocs; i++) {
		relocs[j].vaddr = rz_bin_mz_va_to_la(rel_entry[i].segment,
			rel_entry[i].offset);
		relocs[j].paddr = rz_bin_mz_la_to_pa(bin, relocs[j].vaddr);

		/* Keep only relocations that reside inside the DOS executable;
		 * a rejected slot is overwritten by the next candidate. */
		if (relocs[j].vaddr < bin->load_module_size) {
			j++;
		}
	}
	relocs[j].last = 1;

	return relocs;
}

/**
 * Release an MZ object together with its header copies, relocation table
 * and buffer. Accepts NULL. Always returns NULL so callers can write
 * `return rz_bin_mz_free(bin);` on error paths.
 */
void *rz_bin_mz_free(struct rz_bin_mz_obj_t *bin) {
	if (bin) {
		rz_buf_free(bin->b);
		bin->b = NULL;
		/* The members are const-qualified, hence the casts. */
		free((void *)bin->relocation_entries);
		free((void *)bin->dos_extended_header);
		free((void *)bin->dos_header);
		free(bin);
	}
	return NULL;
}

/* Decode the DOS header from the start of \p buf into \p mz, one
 * little-endian 16-bit field after another. Stops at the first failed read
 * and returns false; returns true once every field has been read. */
static bool read_mz_header(MZ_image_dos_header *mz, RzBuffer *buf) {
	ut64 pos = 0;
	bool ok = rz_buf_read_le16_offset(buf, &pos, &mz->signature);
	ok = ok && rz_buf_read_le16_offset(buf, &pos, &mz->bytes_in_last_block);
	ok = ok && rz_buf_read_le16_offset(buf, &pos, &mz->blocks_in_file);
	ok = ok && rz_buf_read_le16_offset(buf, &pos, &mz->num_relocs);
	ok = ok && rz_buf_read_le16_offset(buf, &pos, &mz->header_paragraphs);
	ok = ok && rz_buf_read_le16_offset(buf, &pos, &mz->min_extra_paragraphs);
	ok = ok && rz_buf_read_le16_offset(buf, &pos, &mz->max_extra_paragraphs);
	ok = ok && rz_buf_read_le16_offset(buf, &pos, &mz->ss);
	ok = ok && rz_buf_read_le16_offset(buf, &pos, &mz->sp);
	ok = ok && rz_buf_read_le16_offset(buf, &pos, &mz->checksum);
	ok = ok && rz_buf_read_le16_offset(buf, &pos, &mz->ip);
	ok = ok && rz_buf_read_le16_offset(buf, &pos, &mz->cs);
	ok = ok && rz_buf_read_le16_offset(buf, &pos, &mz->reloc_table_offset);
	ok = ok && rz_buf_read_le16_offset(buf, &pos, &mz->overlay_number);
	return ok;
}

/**
 * Parse the DOS header, the extended header and the relocation table from
 * bin->b into \p bin.
 *
 * On success fills in dos_header, dos_file_size, load_module_size,
 * dos_extended_header(_size) and relocation_entries, and publishes the
 * initial register values and the header layout in bin->kv.
 *
 * Returns true on success and false when the header cannot be read, is
 * inconsistent with the buffer size, or an allocation fails. Whatever has
 * already been stored in \p bin at that point stays there for the caller to
 * release; the header copy is freed here only while \p bin does not own it
 * yet.
 */
static int rz_bin_mz_init_hdr(struct rz_bin_mz_obj_t *bin) {
	int relocations_size, dos_file_size;
	MZ_image_dos_header *mz;
	if (!(mz = RZ_NEW0(MZ_image_dos_header))) {
		RZ_LOG_ERROR("Cannot allocate MZ_image_dos_header\n");
		return false;
	}
	if (!read_mz_header(mz, bin->b)) {
		RZ_LOG_ERROR("Cannot read MZ_image_dos_header\n");
		/* Not yet owned by bin: release it here to avoid a leak. */
		free(mz);
		return false;
	}
	// The header fields were decoded as little-endian by read_mz_header().
	if (mz->blocks_in_file < 1) {
		free(mz);
		return false;
	}
	if (mz->bytes_in_last_block == 0) {
		// last block is full
		dos_file_size = mz->blocks_in_file << 9;
	} else {
		// last block is partially full
		dos_file_size = ((mz->blocks_in_file - 1) << 9) +
			mz->bytes_in_last_block;
	}

	/* From here on bin owns the header copy. */
	bin->dos_header = mz;
	bin->dos_file_size = dos_file_size;
	if (dos_file_size > bin->size) {
		return false;
	}
	bin->load_module_size = dos_file_size - (mz->header_paragraphs << 4);
	relocations_size = mz->num_relocs * sizeof(MZ_image_relocation_entry);
	if ((mz->reloc_table_offset + relocations_size) > bin->size) {
		return false;
	}

	sdb_num_set(bin->kv, "mz.initial.cs", mz->cs, 0);
	sdb_num_set(bin->kv, "mz.initial.ip", mz->ip, 0);
	sdb_num_set(bin->kv, "mz.initial.ss", mz->ss, 0);
	sdb_num_set(bin->kv, "mz.initial.sp", mz->sp, 0);
	sdb_num_set(bin->kv, "mz.overlay_number", mz->overlay_number, 0);
	sdb_num_set(bin->kv, "mz.dos_header.offset", 0, 0);
	sdb_set(bin->kv, "mz.dos_header.format", "[2]zwwwwwwwwwwwww"
						 " signature bytes_in_last_block blocks_in_file num_relocs "
						 " header_paragraphs min_extra_paragraphs max_extra_paragraphs "
						 " ss sp checksum ip cs reloc_table_offset overlay_number ",
		0);

	/* Everything between the fixed header and the relocation table. */
	bin->dos_extended_header_size = mz->reloc_table_offset -
		sizeof(MZ_image_dos_header);

	if (bin->dos_extended_header_size > 0) {
		if (!(bin->dos_extended_header =
				    malloc(bin->dos_extended_header_size))) {
			RZ_LOG_ERROR("Cannot allocate dos extended header\n");
			return false;
		}
		if (rz_buf_read_at(bin->b, sizeof(MZ_image_dos_header),
			    (ut8 *)bin->dos_extended_header,
			    bin->dos_extended_header_size) == -1) {
			RZ_LOG_ERROR("Cannot read dos extended header\n");
			return false;
		}
	}

	if (relocations_size == 0) {
		return true;
	}
	if (!(bin->relocation_entries = malloc(relocations_size))) {
		RZ_LOG_ERROR("Cannot allocate dos relocation entries\n");
		return false;
	}
	/* Each relocation entry is an offset word followed by a segment word. */
	ut64 offset = bin->dos_header->reloc_table_offset;
	int i;
	for (i = 0; i < mz->num_relocs; i++) {
		MZ_image_relocation_entry *mz_rel_entry = bin->relocation_entries + i;
		if (!rz_buf_read_le16_offset(bin->b, &offset, &mz_rel_entry->offset) ||
			!rz_buf_read_le16_offset(bin->b, &offset, &mz_rel_entry->segment)) {
			RZ_FREE(bin->relocation_entries);
			return false;
		}
	}
	return true;
}

/* Reset the parsed-header members of \p bin, create its key/value database
 * and parse the headers from bin->b.
 *
 * Returns true on success. On failure the database created here is released
 * again and bin->kv is reset to NULL: rz_bin_mz_free() does not touch kv, so
 * it would otherwise leak once the caller discards \p bin. */
static bool rz_bin_mz_init(struct rz_bin_mz_obj_t *bin) {
	bin->dos_header = NULL;
	bin->dos_extended_header = NULL;
	bin->relocation_entries = NULL;
	bin->kv = sdb_new0();
	if (!rz_bin_mz_init_hdr(bin)) {
		RZ_LOG_WARN("File is not MZ\n");
		sdb_free(bin->kv);
		bin->kv = NULL;
		return false;
	}
	return true;
}

/**
 * Create an MZ object from the file at path \p file.
 *
 * The file content is read into memory, copied into a buffer owned by the
 * object and then parsed. Returns the new object, or NULL when the
 * allocation, the read or the parsing fails.
 */
struct rz_bin_mz_obj_t *rz_bin_mz_new(const char *file) {
	struct rz_bin_mz_obj_t *bin = RZ_NEW0(struct rz_bin_mz_obj_t);
	if (!bin) {
		return NULL;
	}
	bin->file = file;
	/* Initialised so it is never read with an indeterminate value. */
	size_t binsz = 0;
	ut8 *buf = (ut8 *)rz_file_slurp(file, &binsz);
	if (!buf) {
		return rz_bin_mz_free(bin);
	}
	/* The size is only meaningful once the read has succeeded. */
	bin->size = binsz;
	bin->b = rz_buf_new_with_bytes(NULL, 0);
	if (!bin->b || !rz_buf_set_bytes(bin->b, buf, bin->size)) {
		free((void *)buf);
		return rz_bin_mz_free(bin);
	}
	/* The temporary copy is released once rz_buf_set_bytes has returned. */
	free((void *)buf);
	if (!rz_bin_mz_init(bin)) {
		return rz_bin_mz_free(bin);
	}
	return bin;
}

/**
 * Create an MZ object from \p buf.
 *
 * The object gets its own buffer created with rz_buf_new_with_buf(), and its
 * size is taken from \p buf. Returns the new object, or NULL when an
 * allocation or the header parsing fails.
 */
struct rz_bin_mz_obj_t *rz_bin_mz_new_buf(RzBuffer *buf) {
	struct rz_bin_mz_obj_t *obj = RZ_NEW0(struct rz_bin_mz_obj_t);
	if (!obj) {
		return NULL;
	}

	obj->b = rz_buf_new_with_buf(buf);
	if (!obj->b) {
		return rz_bin_mz_free(obj);
	}

	obj->size = rz_buf_size(buf);
	if (!rz_bin_mz_init(obj)) {
		return rz_bin_mz_free(obj);
	}
	return obj;
}

/**
 * Try to locate `main` by pattern-matching the startup code at the entry
 * point.
 *
 * Up to 512 bytes are read from the entry point. If they start with the
 * bytes b4 30, the code is scanned for three `ff`-prefixed pushes followed
 * by a far call (9a) and a `push ax` (50); the far call target is taken as
 * the address of main.
 *
 * Returns a newly allocated RzBinAddr describing that target, or NULL when
 * \p bin is unusable, the entry point cannot be determined or read, or the
 * pattern is not found.
 */
RzBinAddr *rz_bin_mz_get_main_vaddr(struct rz_bin_mz_obj_t *bin) {
	ut8 code[512];
	size_t pos;

	if (!bin || !bin->b) {
		return NULL;
	}
	RzBinAddr *addr = rz_bin_mz_get_entrypoint(bin);
	if (!addr) {
		return NULL;
	}

	/* Zeroed first, so bytes past a short read stay 0. */
	ZERO_FILL(code);
	if (rz_buf_read_at(bin->b, addr->paddr, code, sizeof(code)) < 0) {
		RZ_LOG_ERROR("Cannot read entry at 0x%16" PFMT64x "\n", (ut64)addr->paddr);
		free(addr);
		return NULL;
	}

	// MSVC
	if (code[0] != 0xb4 || code[1] != 0x30) {
		free(addr);
		return NULL;
	}

	// ff 36 XX XX			push	XXXX
	// ff 36 XX XX			push	argv
	// ff 36 XX XX			push	argc
	// 9a XX XX XX XX		lcall	_main
	// 50				push	ax
	for (pos = 0; pos + 18 < sizeof(code); pos++) {
		const ut8 *p = code + pos;
		if (p[0] != 0xff || p[4] != 0xff || p[8] != 0xff || p[12] != 0x9a || p[17] != 0x50) {
			continue;
		}
		/* Far call operand: 16-bit offset followed by 16-bit segment. */
		const ut16 target_off = rz_read_ble16(p + 13, 0);
		const ut16 target_seg = rz_read_ble16(p + 15, 0);
		addr->vaddr = rz_bin_mz_va_to_la(target_seg, target_off);
		addr->paddr = rz_bin_mz_la_to_pa(bin, addr->vaddr);
		return addr;
	}

	free(addr);
	return NULL;
}
