//BL2ModHandler by Ethan Alexander Shulman http://mnix.net/
//This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
//http://creativecommons.org/licenses/by-nc-sa/4.0/


#include "d3dx9.h"
#include <stdio.h>
#include <windows.h>


#define DEBUG_OUT


//Signatures of the optional entry points a mod DLL may export.
//LoadMods looks them up under the names "@InitDirectX@4", "@Render@4" and "@ResetDirectX@4".
//Each one is handed the current IDirect3DDevice9 pointer (the d3d9Dev global).
typedef void __fastcall Mod_InitDirectX(IDirect3DDevice9 *d3d9);//called by the CreateDevice wrapper
typedef void __fastcall Mod_Render(IDirect3DDevice9 *d3d9);//called by OnPresent
typedef void __fastcall Mod_ResetDirectX(IDirect3DDevice9 *d3d9);//called by OnReset


//One entry per mod that is enabled in BL2ModSettings.ini; filled in by LoadMods.
struct Mod {
	char fileName[MAX_PATH];//file name as read from the settings file; loaded from the "Mods\" directory

	bool loaded;//true when LoadLibrary succeeded for this mod
	HMODULE module;//handle returned by LoadLibrary (0 when loading failed)

	//optional exports, 0 when the mod does not provide them
	Mod_InitDirectX *initDirectX;
	Mod_Render *render;
	Mod_ResetDirectX *resetDirectX;
};


char *TITLE = "BL2 Mod Handler v0";
int nMods;//number of valid entries in mods
Mod *mods;//heap array of mods, allocated by LoadMods

byte bb[16];//scratch buffer writeVTableOrOverride uses to hold the instruction it relocates

DWORD d3dBaseAddr;//module handle of d3d9.dll, set in Init (only logged)
IDirect3D9 *d3d9;//last IDirect3D9 seen by Direct3DCreate9_WRAPPER
IDirect3DDevice9 *d3d9Dev;//last device seen by the CreateDevice wrapper; passed to every mod callback

//addresses of the original functions that get hooked
DWORD d3dCreateAddr,d3dDeviceResetAddr;
DWORD d3dCreateDeviceAddr=0,d3dDevicePresentAddr=0;//0 = not hooked yet


//Function types used to call the original Direct3D functions through saved addresses.
//The COM methods take the interface pointer ("this") as an explicit first argument.
typedef IDirect3D9 *__stdcall Direct3DCreate9_DEF(UINT sdkVers);
typedef HRESULT __stdcall IDirect3D9_CreateDevice_DEF(IDirect3D9 *dev, UINT adapt, D3DDEVTYPE devType, HWND focusWindow, DWORD behav, D3DPRESENT_PARAMETERS *d3dpres, IDirect3DDevice9 **ddev);
typedef HRESULT __stdcall IDirect3DDevice9_Present_DEF(IDirect3DDevice9 *dev, const RECT *src, const RECT *dst, HWND overrideWnd, const RGNDATA *dirty);
typedef HRESULT __stdcall IDirect3DDevice9_Reset_DEF(IDirect3DDevice9 *dev, D3DPRESENT_PARAMETERS *d3dpp);

//Shared buffer the msg* helpers format their line into before passing it to msg().
char msgoutbuf[1024];

//Appends "out" followed by a newline to BL2ModHandlerLog.txt.
//Silently does nothing when the log file cannot be opened.
void msg(char *out) {
	FILE *logFile;
	if (fopen_s(&logFile, "BL2ModHandlerLog.txt", "a+") == 0) {
		fprintf(logFile, "%s\n", out);
		fclose(logFile);
	}
}
//Logs "head: i" with i printed in decimal.
void msgi(char *head, int i) {
	sprintf_s(msgoutbuf, sizeof(msgoutbuf), "%s: %d", head, i);
	msg(msgoutbuf);
}
//Logs "head: 0x..." with i printed in lower-case hexadecimal.
void msgx(char *head, int i) {
	sprintf_s(msgoutbuf, sizeof(msgoutbuf), "%s: 0x%x", head, i);
	msg(msgoutbuf);
}
//Logs "head: s".
void msgs(char *head, char *s) {
	sprintf_s(msgoutbuf, sizeof(msgoutbuf), "%s: %s", head, s);
	msg(msgoutbuf);
}
//Logs "head: 0x.., 0x.., ..." for the nBytes bytes starting at b.
//With nBytes <= 0 only "head: " is logged.
//The line is built by appending at a running offset; formatting msgoutbuf into
//itself (source and destination overlapping) is undefined for sprintf_s.
void msg_bytes(char *head, byte *b, int nBytes) {
	int used = sprintf_s(msgoutbuf, sizeof(msgoutbuf), "%s: ", head);
	if (used < 0) return;

	for (int i = 0; i < nBytes; i++) {
		size_t room = sizeof(msgoutbuf) - (size_t)used;
		if (room < 8) break;//longest item is "0xff, " plus the terminator; stop instead of overflowing

		int n = sprintf_s(msgoutbuf + used, room, (i + 1 < nBytes) ? "0x%x, " : "0x%x", b[i]);
		if (n < 0) break;
		used += n;
	}
	msg(msgoutbuf);
}

//Requests PAGE_READWRITE protection for sz bytes at addr.
//The previous protection and the VirtualProtect result are discarded.
void mrw(DWORD addr, DWORD sz) {
	DWORD previous = 0;
	VirtualProtect((LPVOID)addr, sz, PAGE_READWRITE, &previous);
}
//Requests PAGE_EXECUTE_READWRITE protection for sz bytes at addr.
//The previous protection and the VirtualProtect result are discarded.
void mrwe(DWORD addr, DWORD sz) {
	DWORD previous = 0;
	VirtualProtect((LPVOID)addr, sz, PAGE_EXECUTE_READWRITE, &previous);
}
//Requests PAGE_EXECUTE_READ protection for sz bytes at addr.
//The previous protection and the VirtualProtect result are discarded.
void mre(DWORD addr, DWORD sz) {
	DWORD previous = 0;
	VirtualProtect((LPVOID)addr, sz, PAGE_EXECUTE_READ, &previous);
}
//Requests PAGE_READONLY protection for sz bytes at addr.
//The previous protection and the VirtualProtect result are discarded.
void mro(DWORD addr, DWORD sz) {
	DWORD previous = 0;
	VirtualProtect((LPVOID)addr, sz, PAGE_READONLY, &previous);
}
//Requests PAGE_EXECUTE protection for sz bytes at addr.
//The previous protection and the VirtualProtect result are discarded.
void me(DWORD addr, DWORD sz) {
	DWORD previous = 0;
	VirtualProtect((LPVOID)addr, sz, PAGE_EXECUTE, &previous);
}

//Returns true when the page containing addr is committed and can be read
//without faulting, false otherwise (including when VirtualQuery fails).
bool IsMemoryReadable(DWORD addr) {
	MEMORY_BASIC_INFORMATION mem;
	if (VirtualQuery((void*)addr, &mem, sizeof(mem)) == 0) return false;

	if (mem.State != MEM_COMMIT) return false;

	//Protect is one access value (low byte) optionally OR'ed with modifier flags,
	//so it has to be masked before it is compared.
	if (mem.Protect & PAGE_GUARD) return false;//touching a guard page raises an exception
	DWORD access = mem.Protect & 0xFF;
	if (access == 0 || access == PAGE_NOACCESS || access == PAGE_EXECUTE) return false;

	return true;
}
//Returns true when the page containing addr is committed and its protection is
//PAGE_READWRITE or PAGE_EXECUTE_READWRITE, false otherwise (including when VirtualQuery fails).
bool IsMemoryReadableAndWritable(DWORD addr) {
	MEMORY_BASIC_INFORMATION mem;
	if (VirtualQuery((void*)addr, &mem, sizeof(mem)) == 0) return false;

	if (mem.State != MEM_COMMIT) return false;

	//Protect is one access value (low byte) optionally OR'ed with modifier flags
	//(PAGE_GUARD, PAGE_NOCACHE, PAGE_WRITECOMBINE); mask them off before comparing.
	if (mem.Protect & PAGE_GUARD) return false;//touching a guard page raises an exception
	DWORD access = mem.Protect & 0xFF;
	return access == PAGE_READWRITE || access == PAGE_EXECUTE_READWRITE;
}

//Replaces the vtable entry at byte offset "offset" of the COM object at obj with newAddr.
//Returns the previous entry, or 0 when the object or the vtable slot is not
//accessible (0 is therefore ambiguous if the slot itself held 0).
//The slot's original page protection is restored after the write.
DWORD writeVTableSafe(DWORD obj, DWORD offset, DWORD newAddr) {
	//the object has to be readable before its vtable pointer can be fetched
	if (!IsMemoryReadable(obj)) return 0;
	DWORD *vtbl = (DWORD*)((*((DWORD*)obj)) + offset);

	if (!IsMemoryReadable((DWORD)vtbl)) return 0;
	DWORD old = *vtbl;

	//unprotect the vtable slot itself (not the object) for the duration of the write
	DWORD oldProtect = 0;
	if (!VirtualProtect((LPVOID)vtbl, sizeof(DWORD), PAGE_READWRITE, &oldProtect)) return 0;
	*vtbl = newAddr;
	VirtualProtect((LPVOID)vtbl, sizeof(DWORD), oldProtect, &oldProtect);
	return old;
}
//Overwrites the vtable entry at byte offset "offset" of the COM object at obj
//with newAddr and returns the entry that was there before.
//No validation is done; the slot is left with PAGE_EXECUTE_READ protection.
DWORD writeVTable(DWORD obj, DWORD offset, DWORD newAddr) {
	DWORD vtableBase = *(DWORD*)obj;//first DWORD of the object is the vtable pointer
	DWORD slotAddr = vtableBase + offset;
	DWORD *slot = (DWORD*)slotAddr;

	DWORD previous = *slot;//remember the original function pointer

	mrw(slotAddr, 4);
	*slot = newAddr;//install our own function pointer
	mre(slotAddr, 4);

	return previous;
}
//Returns the vtable entry at byte offset "offset" of the COM object at obj.
DWORD readVTable(DWORD obj, DWORD offset) {
	DWORD vtableBase = *(DWORD*)obj;//first DWORD of the object is the vtable pointer
	DWORD *slot = (DWORD*)(vtableBase + offset);
	return *slot;
}
//Hooks the method at vtable byte offset "offset" of the COM object at obj.
// - If the method starts with a jmp rel32 (0xE9, treated as the Steam overlay detour),
//   the jmp target is patched in place: "call newAddr" is written over its first
//   5 bytes and its original first instruction is re-emitted right after it with
//   the rel32 adjusted for the 5 byte move. vtblAddr is not used in this case.
// - Otherwise the vtable entry is replaced with vtblAddr; newAddr is not used.
//Returns the original vtable entry in both cases.
//NOTE(review): the detour path assumes the jmp target begins with a 5 byte rel32
//jmp/call and that the 5 bytes after it may be overwritten - this is not verified here.
//NOTE(review): the 10 byte patch is not atomic; a thread executing the target
//while it is rewritten can see a half written instruction.
DWORD writeVTableOrOverride(DWORD obj, DWORD offset, DWORD newAddr, DWORD vtblAddr) {
	DWORD present = readVTable(obj, offset);//current entry of the requested vtable slot
	if (*(byte*)present == 0xE9) {//steam overlay
		DWORD overlayFunc = (present + 5) + *(DWORD*)(present + 1);//absolute jmp target
#ifdef DEBUG_OUT
		msgx("Detour Override", overlayFunc);
#endif
		memcpy(bb, (void*)overlayFunc, 5);//save the instruction that is about to be moved

		//keep the page executable while it is patched: other threads may be running code
		//in it, and plain PAGE_READWRITE would make them fault
		mrwe(overlayFunc, 10);

		*(byte*)overlayFunc = 0xE8;//call rel
		DWORD pb = newAddr - (overlayFunc + 5);
		memcpy((void*)(overlayFunc + 1), &pb, 4);//copy position to jump

		pb = (*(DWORD*)(bb + 1)) + overlayFunc + 5;//absolute target of the saved instruction
		pb -= overlayFunc + 10;//relative to its new location 5 bytes further on
		memcpy(bb + 1, &pb, 4);
		memcpy((void*)(overlayFunc + 5), bb, 5);

		mre(overlayFunc, 10);
	}
	else {//no steam overlay
		writeVTable(obj, offset, vtblAddr);
	}
	return present;
}

//Heuristically finds the length in bytes of the function starting at addr by
//scanning for a return instruction that is followed by a padding byte (0x90, 0xCC or 0).
//Recognised endings are "ret" (0xC3) and "ret imm16" (0xC2 lo hi).
//Returns the offset one past the last byte of the return instruction, or
//maxByteLen when no ending is found within maxByteLen bytes.
//A plain "ret" used to yield the offset of the 0xC3 byte itself, which disagreed
//with the "ret imm16" case and with findFunctionReturnShort/hookFunctionEnd, which
//expect the return instruction to end at addr + length.
//Note: the scan is byte based, not a disassembler, and reads one byte of lookahead.
DWORD findFunctionLength(DWORD addr, DWORD maxByteLen) {
	byte b,b2;
	DWORD i = 0;
	int lr = 5;//bytes consumed since the last 0xC2 was seen; starts above 2 so it cannot match early

	while (i < maxByteLen) {
		b = *(byte*)(addr + i);
		b2 = *(byte*)(addr + i + 1);
		if (b == 0xC2) {//ret with pop
			lr = -1;
		}
		else if (b == 0xC3 && (b2 == 0x90||b2==0xCC||b2==0)) return i + 1;//ret, length includes the ret byte

		i++;
		lr++;
		//lr == 2 means i is now just past the imm16 of a "ret imm16"
		if (lr == 2 && (b2 == 0x90 || b2 == 0xCC || b2 == 0)) return i;
	}

	return maxByteLen;
}
//Inspects the 3 bytes that precede funcEnd.
//Returns 0 when they end in a plain "ret" (0xC3) that is not the operand of a
//0xC2; otherwise returns the last two bytes read as a little-endian 16 bit value
//(the operand of "ret imm16" when the first byte is 0xC2).
unsigned short findFunctionReturnShort(DWORD funcEnd) {
	const byte *tail = (const byte*)(funcEnd - 3);

	bool plainRet = (tail[2] == 0xC3) && (tail[0] != 0xC2);
	if (plainRet) {
		return 0;
	}

	return (unsigned short)((tail[2] << 8) | tail[1]);
}


//maximum number of bytes hookFunctionEndEx lets findFunctionLength scan for the end of a function
#define MAX_SEARCH_SIZE 128

//Hooks the return of the function at funcPtr: its final "ret"/"ret imm16" is
//replaced by a jmp to a small stub that calls hookPtr with the return value (eax)
//as first __fastcall argument (ecx) and then performs the original return.
//Used to hook Direct3DCreate9 when the steam overlay is present.
//  funcPtr - start of the function to hook
//  hookPtr - __fastcall function taking one argument; its return value becomes the hooked function's
//  fl      - length of the function as returned by findFunctionLength
//Returns the address of the newly allocated executable stub, or 0 when it could not be allocated.
//Note: the jmp is 5 bytes long, so the bytes following a shorter return
//instruction (assumed to be padding) are overwritten as well.
DWORD hookFunctionEnd(DWORD funcPtr, DWORD hookPtr, DWORD fl) {
	unsigned short rets = findFunctionReturnShort(funcPtr + fl);

	//address of the return instruction: 1 byte for "ret", 3 bytes for "ret imm16"
	DWORD fPtr = funcPtr+fl-1;
	if (rets != 0) {
		fPtr -= 2;
	}
#ifdef DEBUG_OUT
	msg_bytes("Overwritten Bytes", (byte*)fPtr,5);
#endif

	DWORD nfPtr = (DWORD)VirtualAlloc(0, 64, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
	if (nfPtr == 0) {
#ifdef DEBUG_OUT
		msg("hookFunctionEnd: VirtualAlloc failed");
#endif
		return 0;
	}

	byte *bp = (byte*)nfPtr;
	bp[0] = 0x51;//push ecx
	bp[1] = 0x8B;//mov ecx,eax
	bp[2] = 0xC8;//^move return value(eax) to __fastcall argument 1
	bp[3] = 0xE8;//call hookPtr
	DWORD pb = hookPtr - (nfPtr + 8);//rel32 is relative to the instruction after the call
	memcpy((void*)(nfPtr + 4), &pb, 4);
	bp[8] = 0x59;//pop ecx
	bp[9] = 0xC2;//ret imm16, imm16 = argument bytes the original return popped (0 for plain ret)
	memcpy((void*)(nfPtr + 10), &rets, 2);

	mre(nfPtr, 64);

	//only the 5 bytes of the jmp are written; keep the page executable while patching
	//because other threads may be running code in it
	mrwe(fPtr, 5);

	bp = (byte*)fPtr;
	bp[0] = 0xE9;//jmp rel
	pb = nfPtr - (fPtr + 5);//relative offset of nfPtr
	memcpy((void*)(fPtr + 1), &pb, 4);

	mre(fPtr, 5);

	return nfPtr;
}
//Variant of hookFunctionEnd for when only 3 bytes are available at the return
//instruction. The return is replaced by a 2 byte short jmp (plus a nop) to a
//nearby run of 5 identical padding bytes (0, 0xCC or 0x90), and that run is
//overwritten with a 5 byte jmp to the stub that calls hookPtr with the return value.
//  funcPtr - start of the function to hook
//  hookPtr - __fastcall function taking one argument; its return value becomes the hooked function's
//  fl      - length of the function in bytes
//Returns the address of the newly allocated executable stub, or 0 when no
//usable padding run was found or the stub could not be allocated (nothing is
//patched in that case).
DWORD hookFunctionEndJmpShort(DWORD funcPtr, DWORD hookPtr, DWORD fl) {//returns new allocated executable memory
	unsigned short rets = findFunctionReturnShort(funcPtr + fl);

	//address of the return instruction: 1 byte for "ret", 3 bytes for "ret imm16"
	DWORD fPtr = funcPtr + fl-1;
	if (rets != 0) {
		fPtr -= 2;
	}

	//find 5 identical padding bytes to write the long jmp to.
	//The short jmp sits at fPtr and is 2 bytes long, so its rel8 operand reaches
	//fPtr+2-128 .. fPtr+2+127; the scan starts at fPtr-126 so every hit is reachable.
	DWORD fptr = 0;//start of the padding run, 0 = none found
	int rel8 = 0;//operand of the short jmp
	DWORD row = 0;
	byte repV = 0;
	for (int off = -126; off <= 126; off++) {
		byte rb = *(byte*)(DWORD)(fPtr + off);
		bool pad = (rb == 0 || rb == 0xCC || rb == 0x90);
		if (pad && row > 0 && rb == repV) {
			row++;
		}
		else {
			row = pad ? 1 : 0;
		}
		repV = rb;

		if (row >= 5) {
			int start = off - 4;
			//the run must not overlap the 3 bytes rewritten at fPtr
			if (start <= -5 || start >= 3) {
				fptr = (DWORD)(fPtr + start);
				rel8 = start - 2;
				break;
			}
		}
	}
	if (fptr == 0) {
#ifdef DEBUG_OUT
		msg("hookFunctionEndJmpShort: no padding found near function end, hook not installed");
#endif
		return 0;
	}

#ifdef DEBUG_OUT
	msg_bytes("Overwritten Bytes", (byte*)fPtr, 2);
#endif
	DWORD nfPtr = (DWORD)VirtualAlloc(0, 64, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
	if (nfPtr == 0) {
#ifdef DEBUG_OUT
		msg("hookFunctionEndJmpShort: VirtualAlloc failed");
#endif
		return 0;
	}

	byte *bp = (byte*)nfPtr;
	bp[0] = 0x51;//push ecx
	bp[1] = 0x8B;//mov ecx,eax
	bp[2] = 0xC8;//^move return value(eax) to __fastcall argument 1
	bp[3] = 0xE8;//call hookPtr
	DWORD pb = hookPtr - (nfPtr + 8);//rel32 is relative to the instruction after the call
	memcpy((void*)(nfPtr + 4), &pb, 4);
	bp[8] = 0x59;//pop ecx
	bp[9] = 0xC2;//ret imm16, imm16 = argument bytes the original return popped (0 for plain ret)
	memcpy((void*)(nfPtr + 10), &rets, 2);

	mre(nfPtr, 64);

	//write the long jmp first so the short jmp never leads into unpatched padding.
	//Pages stay executable while patched because other threads may be running code in them.
#ifdef DEBUG_OUT
	msg_bytes("Other Overwritten Bytes", (byte*)fptr, 5);
#endif
	mrwe(fptr, 5);
	*(byte*)fptr = 0xE9;//jmp rel
	*(DWORD*)(fptr + 1) = nfPtr - fptr-5;
	mre(fptr, 5);

	mrwe(fPtr, 3);
	bp = (byte*)fPtr;
	bp[0] = 0xEB;//jmp short
	bp[1] = (byte)(signed char)rel8;
	bp[2] = 0x90;//nop
	mre(fPtr, 3);

	return nfPtr;
}
//Convenience wrapper around hookFunctionEnd that determines the function
//length itself by scanning at most MAX_SEARCH_SIZE bytes.
DWORD hookFunctionEndEx(DWORD funcPtr, DWORD hookPtr) {
	const DWORD detectedLength = findFunctionLength(funcPtr, MAX_SEARCH_SIZE);
	DWORD stub = hookFunctionEnd(funcPtr, hookPtr, detectedLength);
	return stub;
}

void OnPresent() {
	for (int i = 0; i < nMods; i++) {
		if (!mods[i].loaded || mods[i].render == 0) continue;
		mods[i].render(d3d9Dev);
	}
}

//Invokes the resetDirectX callback of every loaded mod that exports one,
//passing the current device, and hands ret back unchanged.
HRESULT __fastcall OnReset(HRESULT ret) {
	for (int idx = 0; idx < nMods; idx++) {
		Mod *mod = &mods[idx];
		if (mod->loaded && mod->resetDirectX != 0) {
			mod->resetDirectX(d3d9Dev);
		}
	}
	return ret;
}
/*
HRESULT __stdcall IDirect3DDevice9_Reset_WRAPPER(IDirect3DDevice9 *dev, D3DPRESENT_PARAMETERS *d3dpp) {
	MessageBox(0, "wat", "the", 0);
	HRESULT hRes = ((IDirect3DDevice9_Reset_DEF*)d3dDeviceResetAddr)(dev, d3dpp);
	for (int i = 0; i < nMods; i++) {
		if (!mods[i].loaded || mods[i].resetDirectX == 0) continue;
		mods[i].resetDirectX(0);
	}
	return hRes;
}*/

//Replacement for IDirect3DDevice9::Present installed in the vtable when no steam overlay is detected.
//Runs the mods' render callbacks first, then forwards to the original Present saved in d3dDevicePresentAddr.
HRESULT __stdcall IDirect3DDevice9_Present_WRAPPER(IDirect3DDevice9 *dev, const RECT *src, const RECT *dst, HWND overWnd, const RGNDATA *dirty) {
	OnPresent();

	IDirect3DDevice9_Present_DEF *originalPresent = (IDirect3DDevice9_Present_DEF*)d3dDevicePresentAddr;
	return originalPresent(dev, src, dst, overWnd, dirty);
}


void OverrideDeviceFunctions() {
	Sleep(100);

	//Present
	d3dDevicePresentAddr = writeVTableOrOverride((DWORD)d3d9Dev, 0x44, (DWORD)&OnPresent, (DWORD)&IDirect3DDevice9_Present_WRAPPER);//0x44 = present offset on vtbl
	
	//Reset
	d3dDeviceResetAddr = readVTable((DWORD)d3d9Dev, 0x40);
	if (*(byte*)d3dDeviceResetAddr == 0xE9) {//steam
		hookFunctionEndJmpShort(d3dDeviceResetAddr, (DWORD)&OnReset, 225);
	}
	else {//no steam
		//todo
	}
}

//wrapper for IDirect3D9->CreateDevice
//Calls the original CreateDevice (saved in d3dCreateDeviceAddr). When a device was
//created it is stored in d3d9Dev, every loaded mod's initDirectX callback is run,
//and a thread is started to hook the device functions.
//When creation failed the result is returned untouched: nothing is stored, no mod
//is called and nothing is hooked.
HRESULT __stdcall IDirect3D9_CreateDevice_WRAPPER(IDirect3D9 *dev, UINT adapt, D3DDEVTYPE devType, HWND focusWindow, DWORD behav, D3DPRESENT_PARAMETERS *d3dpres, IDirect3DDevice9 **ddev) {
	HRESULT hRes = ((IDirect3D9_CreateDevice_DEF*)d3dCreateDeviceAddr)(dev,adapt,devType,focusWindow,behav,d3dpres,ddev);

	//*ddev is only meaningful when the call succeeded
	if (FAILED(hRes) || ddev == 0 || *ddev == 0) {
#ifdef DEBUG_OUT
		msgx("IDirect3D9 CreateDevice failed", hRes);
#endif
		return hRes;
	}

	d3d9Dev = *ddev;
#ifdef DEBUG_OUT
	msgx("IDirect3DDevice9 Ptr", (DWORD)d3d9Dev);
#endif

	for (int i = 0; i < nMods; i++) {
		if (!mods[i].loaded || mods[i].initDirectX == 0) continue;
#ifdef DEBUG_OUT
		msgs("Loading DirectX in mod ", mods[i].fileName);
#endif
		mods[i].initDirectX(d3d9Dev);
	}

	//the hooks are installed from a separate thread after a short delay
	HANDLE hookThread = CreateThread(0, 0, (LPTHREAD_START_ROUTINE)OverrideDeviceFunctions, 0, 0, 0);
	if (hookThread != 0) {
		CloseHandle(hookThread);//the thread keeps running; only our handle to it is released
	}
	return hRes;
}

//wrapper for Direct3DCreate9
//Called by the stub installed in Init with the value Direct3DCreate9 is about to
//return. Remembers the interface and, the first time a valid one is seen, replaces
//its CreateDevice vtable entry with IDirect3D9_CreateDevice_WRAPPER.
//Returns d9 unchanged so the caller of Direct3DCreate9 gets the original result.
IDirect3D9 * __fastcall Direct3DCreate9_WRAPPER(IDirect3D9 *d9) {
	d3d9 = d9;

#ifdef DEBUG_OUT
	msgx("IDirect3D9 Ptr", (DWORD)d9);
#endif

	//d9 is 0 when Direct3DCreate9 failed; there is no vtable to patch then
	if (d9 != 0 && d3dCreateDeviceAddr == 0) {
		d3dCreateDeviceAddr = writeVTable((DWORD)d3d9, 0x40, (DWORD)&IDirect3D9_CreateDevice_WRAPPER);//0x40 = createdevice offset on vtable
	}

	return d9;
}




void LoadMods() {
	//load mod settings file set by BL2ModLauncher
	FILE *fp;
	fopen_s(&fp, "BL2ModSettings.ini", "r");

	int modBufSz = 2;
	mods = (Mod*)malloc(sizeof(Mod)*modBufSz);
	nMods = 0;

	int cc = 0,st = 0;
	char cb[MAX_PATH+1],rc;
	int bib;
	while (!feof(fp)) {//parse
		rc = fgetc(fp);
		if (rc == '=') {
			st = 1;
			cb[cc] = '\0';
			cc++;
			continue;
		}

		if (st == 0) {
			cb[cc] = rc;
			cc++;
		}
		else {
			if (rc == '1') {
				if (nMods == modBufSz) {
					modBufSz *= 4;
					mods = (Mod*)realloc(mods,sizeof(Mod)*modBufSz);
				}
				memcpy(mods[nMods].fileName, cb, cc);
				nMods++;
			}

			fgetc(fp);
			st = 0;
			cc = 0;
		}
	}

	fclose(fp);

	HMODULE hm;
	char fb[MAX_PATH + 7];
	for (int i = 0; i < nMods; i++) {//attempt to load mods
#ifdef DEBUG_OUT
		msgs("Loading ", mods[i].fileName);
#endif
		sprintf_s(fb, "Mods\\%s", mods[i].fileName);
		hm = mods[i].module = LoadLibrary(fb);
		if (hm == 0) {
#ifdef DEBUG_OUT
			msgs("Failed to load mod ", mods[i].fileName);
#endif
			mods[i].loaded = false;
			continue;
		}
		
		mods[i].loaded = true;
		
		//find event functions
		mods[i].initDirectX = (Mod_InitDirectX*)GetProcAddress(hm, "@InitDirectX@4");
		mods[i].render = (Mod_Render*)GetProcAddress(hm, "@Render@4");
		mods[i].resetDirectX = (Mod_ResetDirectX*)GetProcAddress(hm, "@ResetDirectX@4");
	}
}


//One-time setup run when the DLL is attached to the process: empties the log
//file, loads the mods and hooks the end of Direct3DCreate9 so that
//Direct3DCreate9_WRAPPER sees every IDirect3D9 the game creates.
//hi is the DLL's instance handle (currently unused).
void Init(HINSTANCE hi) {
	//clear log file first, otherwise the messages below would be wiped again.
	//Opening with "w" already truncates the file, nothing has to be written.
	FILE *fp;
	if (fopen_s(&fp, "BL2ModHandlerLog.txt", "w") == 0) {
		fclose(fp);
	}

#ifdef DEBUG_OUT
	msg("\n\nBL2ModHandler.dll injected");
#endif

	LoadMods();

	//inject into directx
	d3dBaseAddr = (DWORD)GetModuleHandle("d3d9.dll");

	//grab pointer to direct3dcreate9
	d3dCreateAddr = (DWORD)&Direct3DCreate9;

	//hook direct3dcreate9
	DWORD stub = hookFunctionEndEx(d3dCreateAddr,(DWORD)&Direct3DCreate9_WRAPPER);
#ifdef DEBUG_OUT
	if (stub == 0) {
		msg("Failed to hook Direct3DCreate9");
	}
	msgx("d3d9.dll", d3dBaseAddr);
	msgx("Direct3DCreate9", d3dCreateAddr);
#else
	(void)stub;
#endif
}

//DLL entry point: runs Init when the DLL is attached to a process and ignores
//every other notification. Always reports success.
bool WINAPI DllMain(HINSTANCE hInst, DWORD reason, LPVOID reserved) {
	switch (reason) {
	case DLL_PROCESS_ATTACH:
		Init(hInst);
		break;
	default:
		break;
	}
	return true;
}