package de.uni_frankfurt.prgpr.core;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarInputStream;

/**
 * Class loader that defines classes from a JAR archive held in memory as a byte array.
 *
 * <p>Classes contained in the JAR take precedence over classes visible through the parent
 * loader, except for the names listed in {@link #FORBIDDEN_CLASSES}: those are never defined
 * from the JAR and are always delegated to the parent.
 */
public class InMemoryJarClassLoader extends ClassLoader {
	/** File-name suffix that marks a JAR entry as a class file. */
	private static final String CLASS_SUFFIX = ".class";

	/** Size of the scratch buffer used while reading a JAR entry. */
	private static final int BYTES_PER_BLOCK = 1024;

	/** Raw bytes of the JAR archive, exactly as passed to the constructor. */
	private final byte[] data;

	/** Binary names of classes that must never be defined from the in-memory JAR. */
	final Set<String> FORBIDDEN_CLASSES = new HashSet<>();
	
	{
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.ClientInterface");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.ChannelInterface");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.ServerInterface");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.ServerServiceInterface");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.LogInterface");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.NewServerInterface");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.Config");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.phase3.gui.UIPanel");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.phase3.gui.UIEventDelegate");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.phase3.gui.Drawer");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.phase2.gui.UIPanel");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.phase2.gui.UIEventDelegate");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.phase2.gui.Drawer");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.phase3.storage.ClientConfig");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.phase3.storage.UserConfig");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.Message");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.Message$Login");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.Message$Code");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.Message$Connect");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.Message$Update");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.Message$Client");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.Message$GetClient");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.Message$Error");
		FORBIDDEN_CLASSES.add("de.uni_frankfurt.prgpr.core.Message$OK");
	}
	
	/** Class-file bytes found in the JAR, keyed by binary class name. */
	private final Map<String, byte[]> classes = new HashMap<>();

	/**
	 * Creates a loader for the given in-memory JAR.
	 *
	 * <p>Every forbidden class is requested once up front. This is best-effort: a class that
	 * cannot be loaded is skipped and the remaining ones are still attempted.
	 *
	 * @param parent the parent class loader used for delegation
	 * @param data   the bytes of the JAR archive; stored by reference, not copied
	 * @throws RuntimeException if the archive cannot be read (wraps the {@link IOException})
	 */
	public InMemoryJarClassLoader(ClassLoader parent, byte[] data) {
		super(parent);
		this.data = data;
		for (String forbidden : FORBIDDEN_CLASSES) {
			// The try sits inside the loop so one missing class does not stop the others
			// from being requested.
			try {
				super.loadClass(forbidden);
			} catch (ClassNotFoundException exn) {
				exn.printStackTrace();
			} catch (RuntimeException ignored) {
				/* Ignore; this is a best-effort thing */
			}
		}
		init();
	}
	
	/**
	 * Reads the JAR archive and records the bytes of every entry whose name ends in
	 * {@code .class}, keyed by the class name derived from the entry path. All other entries
	 * are skipped. Invoked by the constructor.
	 *
	 * @throws RuntimeException if reading the archive fails (wraps the {@link IOException})
	 */
	public void init() {
		try (JarInputStream jis = new JarInputStream(new ByteArrayInputStream(this.data))) {
			for (JarEntry je = jis.getNextJarEntry(); je != null; je = jis.getNextJarEntry()) {
				final String entryName = je.getName();
				if (!entryName.endsWith(CLASS_SUFFIX)) {
					// manifest file, directory or other data resource
					continue;
				}
				final String className = entryName
						.substring(0, entryName.length() - CLASS_SUFFIX.length())
						.replace('/', '.');
				classes.put(className, readCurrentEntry(jis));
			}
		} catch (IOException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Reads the remaining bytes of the JAR entry the stream is currently positioned at.
	 *
	 * @param jis the stream to read from
	 * @return the bytes of the current entry
	 * @throws IOException if reading fails
	 */
	private static byte[] readCurrentEntry(JarInputStream jis) throws IOException {
		final ByteArrayOutputStream collected = new ByteArrayOutputStream();
		final byte[] block = new byte[BYTES_PER_BLOCK];
		int bytesRead;
		while ((bytesRead = jis.read(block, 0, block.length)) > -1) {
			collected.write(block, 0, bytesRead);
		}
		return collected.toByteArray();
	}
	
	/**
	 * Returns the JAR bytes this loader was created with.
	 *
	 * @return the same array instance that was passed to the constructor
	 */
	public byte[] getData() {
		return this.data;
	}
	
	/**
	 * Defines the named class from the bytes recorded for it, if any.
	 *
	 * @param name the binary class name
	 * @return the newly defined class, or {@code null} if the JAR has no entry for {@code name}
	 */
	private Class<?> loadInMemory(String name) {
		final byte[] b = classes.get(name);
		if (b == null) {
			return null;
		}
		return defineClass(name, b, 0, b.length);
	}
	
	/**
	 * Finds the named class among the classes recorded from the JAR.
	 *
	 * @param name the binary class name
	 * @return the class defined from the JAR
	 * @throws RuntimeException       if {@code name} is one of {@link #FORBIDDEN_CLASSES}
	 * @throws ClassNotFoundException if the JAR has no entry for {@code name}
	 */
	@Override
	public Class<?> findClass(String name) throws ClassNotFoundException {
		if (FORBIDDEN_CLASSES.contains(name)) {
			throw new RuntimeException(name);
		}
		final Class<?> c = loadInMemory(name);
		if (c == null) {
			return super.findClass(name);
		}
		return c;
	}
	
	/**
	 * Loads the named class, looking in the in-memory JAR before asking the parent, unless
	 * the name is forbidden, in which case the JAR lookup is skipped.
	 *
	 * @param name      the binary class name
	 * @param doResolve whether to resolve the class after loading it
	 * @return the loaded class
	 * @throws ClassNotFoundException if the class cannot be found
	 */
	@Override
	public Class<?> loadClass(String name, boolean doResolve)  throws ClassNotFoundException {
		synchronized (getClassLoadingLock(name)) {
			Class<?> c = findLoadedClass(name);
			if (c != null) {
				return c;
			}
			if (!FORBIDDEN_CLASSES.contains(name)) {
				c = loadInMemory(name);
			}
			if (c == null) {
				// Standard delegation: parent first, then findClass(name) on this loader.
				c = super.loadClass(name, doResolve);
			}
			if (c == null) {
				// Only reachable if an overriding findClass returns null instead of throwing.
				throw new ClassNotFoundException(name);
			}
			if (doResolve) {
				resolveClass(c);
			}
			return c;
		}
	}
}
