/*
	This file is part of solidity.
	solidity is free software: you can redistribute it and/or modify
	it under the terms of the GNU General Public License as published by
	the Free Software Foundation, either version 3 of the License, or
	(at your option) any later version.
	solidity is distributed in the hope that it will be useful,
	but WITHOUT ANY WARRANTY; without even the implied warranty of
	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
	GNU General Public License for more details.
	You should have received a copy of the GNU General Public License
	along with solidity.  If not, see <http://www.gnu.org/licenses/>.
*/
/**
 * Component that can generate various useful Yul functions.
 */
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;
using namespace dev;
using namespace dev::solidity;
/// Requests the Yul helper "combine_external_function_id" from the function collector
/// and returns what the collector hands back (used elsewhere in this file as a callee name).
/// The generated function packs an address and a selector into one value:
/// the address is shifted left by 32 bits, the low 32 bits of the selector are or-ed in,
/// and the result is shifted left by a further 64 bits.
string YulUtilFunctions::combineExternalFunctionIdFunction()
{
	string functionName = "combine_external_function_id";
	return m_functionCollector->createFunction(functionName, [&]() {
		return Whiskers(R"(
			function <functionName>(addr, selector) -> combined {
				combined := <shl64>(or(<shl32>(addr), and(selector, 0xffffffff)))
			}
		)")
		("functionName", functionName)
		("shl32", shiftLeftFunction(32))
		("shl64", shiftLeftFunction(64))
		.render();
	});
}
/// Requests the Yul helper "split_external_function_id" from the function collector.
/// The generated function is the inverse of the packing done by
/// "combine_external_function_id": it shifts the combined value right by 64 bits,
/// takes the low 32 bits as the selector and the value shifted right by another
/// 32 bits as the address.
string YulUtilFunctions::splitExternalFunctionIdFunction()
{
	string functionName = "split_external_function_id";
	return m_functionCollector->createFunction(functionName, [&]() {
		return Whiskers(R"(
			function <functionName>(combined) -> addr, selector {
				combined := <shr64>(combined)
				selector := and(combined, 0xffffffff)
				addr := <shr32>(combined)
			}
		)")
		("functionName", functionName)
		("shr32", shiftRightFunction(32))
		("shr64", shiftRightFunction(64))
		.render();
	});
}
/// Requests a Yul helper that copies `length` bytes from `src` to `dst` in memory.
/// @param _fromCalldata if true, the source is calldata ("copy_calldata_to_memory"),
/// otherwise it is memory ("copy_memory_to_memory").
/// Calldata variant: uses calldatacopy and then stores a zero word at dst + length.
/// Memory variant: copies in 32-byte words, so the last word may be copied past
/// `length`; if that happened, a zero word is stored at dst + length afterwards.
string YulUtilFunctions::copyToMemoryFunction(bool _fromCalldata)
{
	string functionName = "copy_" + string(_fromCalldata ? "calldata" : "memory") + "_to_memory";
	return m_functionCollector->createFunction(functionName, [&]() {
		if (_fromCalldata)
		{
			return Whiskers(R"(
				function <functionName>(src, dst, length) {
					calldatacopy(dst, src, length)
					// clear end
					mstore(add(dst, length), 0)
				}
			)")
			("functionName", functionName)
			.render();
		}
		else
		{
			return Whiskers(R"(
				function <functionName>(src, dst, length) {
					let i := 0
					for { } lt(i, length) { i := add(i, 32) }
					{
						mstore(add(dst, i), mload(add(src, i)))
					}
					if gt(i, length)
					{
						// clear end
						mstore(add(dst, length), 0)
					}
				}
			)")
			("functionName", functionName)
			.render();
		}
	});
}
/// Requests a Yul helper "leftAlign_<type identifier>" that moves a value of the
/// given type to the high-order end of a 256-bit word.
/// Integers are shifted left by (256 - numBits); addresses, bools, contracts and
/// enums delegate to the helper of a suitably sized integer (or address) type;
/// fixed bytes are returned unchanged. Rational numbers, arrays, structs,
/// inaccessible dynamic types and all remaining categories trigger an assertion,
/// fixed point types are reported as unimplemented.
string YulUtilFunctions::leftAlignFunction(Type const& _type)
{
	string functionName = string("leftAlign_") + _type.identifier();
	return m_functionCollector->createFunction(functionName, [&]() {
		Whiskers templ(R"(
			function <functionName>(value) -> aligned {
				<body>
			}
		)");
		templ("functionName", functionName);
		switch (_type.category())
		{
		case Type::Category::Address:
			templ("body", "aligned := " + leftAlignFunction(IntegerType(160)) + "(value)");
			break;
		case Type::Category::Integer:
		{
			IntegerType const& type = dynamic_cast<IntegerType const&>(_type);
			if (type.numBits() == 256)
				// Already occupies the full word.
				templ("body", "aligned := value");
			else
				templ("body", "aligned := " + shiftLeftFunction(256 - type.numBits()) + "(value)");
			break;
		}
		case Type::Category::RationalNumber:
			solAssert(false, "Left align requested for rational number.");
			break;
		case Type::Category::Bool:
			templ("body", "aligned := " + leftAlignFunction(IntegerType(8)) + "(value)");
			break;
		case Type::Category::FixedPoint:
			solUnimplemented("Fixed point types not implemented.");
			break;
		case Type::Category::Array:
		case Type::Category::Struct:
			solAssert(false, "Left align requested for non-value type.");
			break;
		case Type::Category::FixedBytes:
			templ("body", "aligned := value");
			break;
		case Type::Category::Contract:
			templ("body", "aligned := " + leftAlignFunction(*TypeProvider::address()) + "(value)");
			break;
		case Type::Category::Enum:
		{
			// Treat the enum like an unsigned integer of its storage size.
			unsigned storageBytes = dynamic_cast<EnumType const&>(_type).storageBytes();
			templ("body", "aligned := " + leftAlignFunction(IntegerType(8 * storageBytes)) + "(value)");
			break;
		}
		case Type::Category::InaccessibleDynamic:
			solAssert(false, "Left align requested for inaccessible dynamic type.");
			break;
		default:
			solAssert(false, "Left align of type " + _type.identifier() + " requested.");
		}
		return templ.render();
	});
}
/// Requests a Yul helper "shift_left_<numBits>" that shifts its argument left by a
/// constant number of bits.
/// @param _numBits shift amount, asserted to be smaller than 256.
/// If the EVM version reports bitwise shifting support, the `shl` opcode is used;
/// otherwise the shift is expressed as a multiplication by 2**_numBits.
string YulUtilFunctions::shiftLeftFunction(size_t _numBits)
{
	solAssert(_numBits < 256, "");
	string functionName = "shift_left_" + to_string(_numBits);
	if (m_evmVersion.hasBitwiseShifting())
	{
		return m_functionCollector->createFunction(functionName, [&]() {
			return
				Whiskers(R"(
				function <functionName>(value) -> newValue {
					newValue := shl(<numBits>, value)
				}
				)")
				("functionName", functionName)
				("numBits", to_string(_numBits))
				.render();
		});
	}
	else
	{
		return m_functionCollector->createFunction(functionName, [&]() {
			return
				Whiskers(R"(
				function <functionName>(value) -> newValue {
					newValue := mul(value, <multiplier>)
				}
				)")
				("functionName", functionName)
				("multiplier", toCompactHexWithPrefix(u256(1) << _numBits))
				.render();
		});
	}
}
/// Requests a Yul helper "shift_right_<numBits>_unsigned" that shifts its argument
/// right by a constant number of bits (unsigned / logical shift).
/// @param _numBits shift amount, asserted to be smaller than 256.
/// If the EVM version reports bitwise shifting support, the `shr` opcode is used;
/// otherwise the shift is expressed as a division by 2**_numBits.
string YulUtilFunctions::shiftRightFunction(size_t _numBits)
{
	solAssert(_numBits < 256, "");
	// Note that if this is extended with signed shifts,
	// the opcodes SAR and SDIV behave differently with regards to rounding!
	string functionName = "shift_right_" + to_string(_numBits) + "_unsigned";
	if (m_evmVersion.hasBitwiseShifting())
	{
		return m_functionCollector->createFunction(functionName, [&]() {
			return
				Whiskers(R"(
				function <functionName>(value) -> newValue {
					newValue := shr(<numBits>, value)
				}
				)")
				("functionName", functionName)
				("numBits", to_string(_numBits))
				.render();
		});
	}
	else
	{
		return m_functionCollector->createFunction(functionName, [&]() {
			return
				Whiskers(R"(
				function <functionName>(value) -> newValue {
					newValue := div(value, <multiplier>)
				}
				)")
				("functionName", functionName)
				("multiplier", toCompactHexWithPrefix(u256(1) << _numBits))
				.render();
		});
	}
}
/// Requests the Yul helper "round_up_to_mul_of_32", which rounds its argument up to
/// the next multiple of 32 by adding 31 and clearing the five lowest bits.
string YulUtilFunctions::roundUpFunction()
{
	string functionName = "round_up_to_mul_of_32";
	return m_functionCollector->createFunction(functionName, [&]() {
		return
			Whiskers(R"(
			function <functionName>(value) -> result {
				result := and(add(value, 31), not(31))
			}
			)")
			("functionName", functionName)
			.render();
	});
}
/// Requests a Yul helper "checked_add_uint_<bits>" that adds two unsigned integers
/// and reverts with empty data on overflow.
/// @param _bits width of the integer type; asserted to be a non-zero multiple of 8
/// that does not exceed 256.
/// For widths below 256 both operands are masked to the width before adding and the
/// sum is checked for bits outside the mask; for 256 bits a wrap-around is detected
/// by the sum being smaller than the first operand.
string YulUtilFunctions::overflowCheckedUIntAddFunction(size_t _bits)
{
	solAssert(0 < _bits && _bits <= 256 && _bits % 8 == 0, "");
	string functionName = "checked_add_uint_" + to_string(_bits);
	return m_functionCollector->createFunction(functionName, [&]() {
		if (_bits < 256)
			return
				Whiskers(R"(
				function <functionName>(x, y) -> sum {
					let mask := <mask>
					sum := add(and(x, mask), and(y, mask))
					if and(sum, not(mask)) { revert(0, 0) }
				}
				)")
				("functionName", functionName)
				("mask", toCompactHexWithPrefix((u256(1) << _bits) - 1))
				.render();
		else
			return
				Whiskers(R"(
				function <functionName>(x, y) -> sum {
					sum := add(x, y)
					if lt(sum, x) { revert(0, 0) }
				}
				)")
				("functionName", functionName)
				.render();
	});
}
/// Requests a Yul helper "array_length_<type identifier>" that returns the length of
/// an array given its reference (`value`).
/// Statically sized arrays yield their compile-time length. Dynamic memory arrays
/// read the length with mload, dynamic storage arrays with sload (byte arrays decode
/// the length from the slot content, see the comment below). Dynamic calldata arrays
/// are not supported here and trigger an assertion.
string YulUtilFunctions::arrayLengthFunction(ArrayType const& _type)
{
	string functionName = "array_length_" + _type.identifier();
	return m_functionCollector->createFunction(functionName, [&]() {
		Whiskers w(R"(
			function <functionName>(value) -> length {
				<body>
			}
		)");
		w("functionName", functionName);
		string body;
		if (!_type.isDynamicallySized())
			body = "length := " + toCompactHexWithPrefix(_type.length());
		else
		{
			switch (_type.location())
			{
			case DataLocation::CallData:
				solAssert(false, "called regular array length function on calldata array");
				break;
			case DataLocation::Memory:
				body = "length := mload(value)";
				break;
			case DataLocation::Storage:
				if (_type.isByteArray())
				{
					// Retrieve length both for in-place strings and off-place strings:
					// Computes (x & (0x100 * (ISZERO (x & 1)) - 1)) / 2
					// i.e. for short strings (x & 1 == 0) it does (x & 0xff) / 2 and for long strings it
					// computes (x & (-1)) / 2, which is equivalent to just x / 2.
					body = R"(
						length := sload(value)
						let mask := sub(mul(0x100, iszero(and(length, 1))), 1)
						length := div(and(length, mask), 2)
					)";
				}
				else
					body = "length := sload(value)";
				break;
			}
		}
		// Every supported combination must have produced code.
		solAssert(!body.empty(), "");
		w("body", body);
		return w.render();
	});
}
/// Requests a Yul helper "array_allocation_size_<type identifier>" that computes the
/// number of bytes to allocate for a memory array of the given length.
/// The type is asserted to be stored in memory. The generated code reverts for
/// lengths above 2**64 - 1. Byte arrays use the length rounded up to a multiple of
/// 32, other arrays use 32 bytes per element; dynamically sized arrays add another
/// 32 bytes (the "length slot").
string YulUtilFunctions::arrayAllocationSizeFunction(ArrayType const& _type)
{
	solAssert(_type.dataStoredIn(DataLocation::Memory), "");
	string functionName = "array_allocation_size_" + _type.identifier();
	return m_functionCollector->createFunction(functionName, [&]() {
		Whiskers w(R"(
			function <functionName>(length) -> size {
				// Make sure we can allocate memory without overflow
				if gt(length, 0xffffffffffffffff) { revert(0, 0) }
				size := <allocationSize>
				<addLengthSlot>
			}
		)");
		w("functionName", functionName);
		if (_type.isByteArray())
			// Round up
			w("allocationSize", "and(add(length, 0x1f), not(0x1f))");
		else
			w("allocationSize", "mul(length, 0x20)");
		if (_type.isDynamicallySized())
			w("addLengthSlot", "size := add(size, 0x20)");
		else
			w("addLengthSlot", "");
		return w.render();
	});
}
/// Requests a Yul helper "array_dataslot_<type identifier>" that maps an array
/// reference to the position of its first element.
/// - Memory: dynamic arrays skip the first 32 bytes, static arrays return the pointer.
/// - Storage: dynamic arrays return keccak256 of the slot number, static arrays the slot.
/// - Calldata: the input is returned unchanged.
string YulUtilFunctions::arrayDataAreaFunction(ArrayType const& _type)
{
	string functionName = "array_dataslot_" + _type.identifier();
	return m_functionCollector->createFunction(functionName, [&]() {
		switch (_type.location())
		{
			case DataLocation::Memory:
				if (_type.isDynamicallySized())
					return Whiskers(R"(
						function <functionName>(memPtr) -> dataPtr {
							dataPtr := add(memPtr, 0x20)
						}
					)")
					("functionName", functionName)
					.render();
				else
					return Whiskers(R"(
						function <functionName>(memPtr) -> dataPtr {
							dataPtr := memPtr
						}
					)")
					("functionName", functionName)
					.render();
			case DataLocation::Storage:
				if (_type.isDynamicallySized())
				{
					Whiskers w(R"(
						function <functionName>(slot) -> dataSlot {
							mstore(0, slot)
							dataSlot := keccak256(0, 0x20)
						}
					)");
					w("functionName", functionName);
					return w.render();
				}
				else
				{
					Whiskers w(R"(
						function <functionName>(slot) -> dataSlot {
							dataSlot := slot
						}
					)");
					w("functionName", functionName);
					return w.render();
				}
			case DataLocation::CallData:
			{
				// Calldata arrays are stored as offset of the data area and length
				// on the stack, so the offset already points to the data area.
				// This might change, if calldata arrays are stored in a single
				// stack slot at some point.
				Whiskers w(R"(
					function <functionName>(slot) -> dataSlot {
						dataSlot := slot
					}
				)");
				w("functionName", functionName);
				return w.render();
			}
			default:
				solAssert(false, "");
		}
	});
}
/// Requests a Yul helper "array_nextElement_<type identifier>" that advances a
/// pointer/slot from one array element to the next.
/// Byte arrays are rejected by assertion; for storage arrays the base type is
/// asserted to take more than 16 storage bytes.
/// The step is 32 bytes in memory, one slot in storage, and in calldata either 32
/// (dynamically encoded base type) or the base type's calldata encoded size.
string YulUtilFunctions::nextArrayElementFunction(ArrayType const& _type)
{
	solAssert(!_type.isByteArray(), "");
	if (_type.dataStoredIn(DataLocation::Storage))
		solAssert(_type.baseType()->storageBytes() > 16, "");
	string functionName = "array_nextElement_" + _type.identifier();
	return m_functionCollector->createFunction(functionName, [&]() {
		switch (_type.location())
		{
			case DataLocation::Memory:
				return Whiskers(R"(
					function <functionName>(memPtr) -> nextPtr {
						nextPtr := add(memPtr, 0x20)
					}
				)")
				("functionName", functionName)
				.render();
			case DataLocation::Storage:
				return Whiskers(R"(
					function <functionName>(slot) -> nextSlot {
						nextSlot := add(slot, 1)
					}
				)")
				("functionName", functionName)
				.render();
			case DataLocation::CallData:
				return Whiskers(R"(
					function <functionName>(calldataPtr) -> nextPtr {
						nextPtr := add(calldataPtr, <stride>)
					}
				)")
				("stride", toCompactHexWithPrefix(_type.baseType()->isDynamicallyEncoded() ? 32 : _type.baseType()->calldataEncodedSize()))
				("functionName", functionName)
				.render();
			default:
				solAssert(false, "");
		}
	});
}
/// Requests the Yul helper "allocateMemory", which reserves `size` bytes of memory.
/// The generated code reads the current free memory pointer (stored at
/// CompilerUtils::freeMemoryPointer), returns it, and writes back the pointer advanced
/// by `size`. It reverts if the new pointer exceeds 2**64 - 1 or wrapped around.
string YulUtilFunctions::allocationFunction()
{
	string functionName = "allocateMemory";
	return m_functionCollector->createFunction(functionName, [&]() {
		return Whiskers(R"(
			function <functionName>(size) -> memPtr {
				memPtr := mload(<freeMemoryPointer>)
				let newFreePtr := add(memPtr, size)
				// protect against overflow
				if or(gt(newFreePtr, 0xffffffffffffffff), lt(newFreePtr, memPtr)) { revert(0, 0) }
				mstore(<freeMemoryPointer>, newFreePtr)
			}
		)")
		("freeMemoryPointer", to_string(CompilerUtils::freeMemoryPointer))
		("functionName", functionName)
		.render();
	});
}
/// Requests a Yul helper "convert_<from identifier>_to_<to identifier>" that converts
/// a single stack value from type @a _from to type @a _to.
/// The conversion is selected by the category of the source type:
/// - Address: handled like a 160-bit integer.
/// - Integer / RationalNumber / Contract: to fixed bytes (clean, then shift left),
///   to enum (integer cleanup followed by enum cleanup), to address (via 160-bit
///   integer) or to integer/contract (a single integer cleanup).
/// - Bool, Enum: cleanup of the source type.
/// - FixedBytes: to integer (shift right, then integer conversion), to address
///   (via 160-bit integer) or to other fixed bytes (cleanup of the source type).
/// Fixed point, array, struct and tuple conversions are reported as unimplemented;
/// function types and any other category trigger an assertion.
string YulUtilFunctions::conversionFunction(Type const& _from, Type const& _to)
{
	string functionName =
		"convert_" +
		_from.identifier() +
		"_to_" +
		_to.identifier();
	return m_functionCollector->createFunction(functionName, [&]() {
		Whiskers templ(R"(
			function <functionName>(value) -> converted {
				<body>
			}
		)");
		templ("functionName", functionName);
		string body;
		auto toCategory = _to.category();
		auto fromCategory = _from.category();
		switch (fromCategory)
		{
		case Type::Category::Address:
			body =
				Whiskers("converted := <convert>(value)")
					("convert", conversionFunction(IntegerType(160), _to))
					.render();
			break;
		case Type::Category::Integer:
		case Type::Category::RationalNumber:
		case Type::Category::Contract:
		{
			if (RationalNumberType const* rational = dynamic_cast<RationalNumberType const*>(&_from))
				solUnimplementedAssert(!rational->isFractional(), "Not yet implemented - FixedPointType.");
			if (toCategory == Type::Category::FixedBytes)
			{
				solAssert(
					fromCategory == Type::Category::Integer || fromCategory == Type::Category::RationalNumber,
					"Invalid conversion to FixedBytesType requested."
				);
				FixedBytesType const& toBytesType = dynamic_cast<FixedBytesType const&>(_to);
				// Clean as the source type first, then move the bytes to the high-order end.
				body =
					Whiskers("converted := <shiftLeft>(<clean>(value))")
						("shiftLeft", shiftLeftFunction(256 - toBytesType.numBytes() * 8))
						("clean", cleanupFunction(_from))
						.render();
			}
			else if (toCategory == Type::Category::Enum)
			{
				solAssert(_from.mobileType(), "");
				body =
					Whiskers("converted := <cleanEnum>(<cleanInt>(value))")
					("cleanEnum", cleanupFunction(_to))
					// "mobileType()" returns integer type for rational
					("cleanInt", cleanupFunction(*_from.mobileType()))
					.render();
			}
			else if (toCategory == Type::Category::FixedPoint)
				solUnimplemented("Not yet implemented - FixedPointType.");
			else if (toCategory == Type::Category::Address)
				body =
					Whiskers("converted := <convert>(value)")
						("convert", conversionFunction(_from, IntegerType(160)))
						.render();
			else
			{
				solAssert(
					toCategory == Type::Category::Integer ||
					toCategory == Type::Category::Contract,
				"");
				// Contracts are treated as 160-bit integers on either side.
				IntegerType const addressType(160);
				IntegerType const& to =
					toCategory == Type::Category::Integer ?
					dynamic_cast<IntegerType const&>(_to) :
					addressType;
				// Clean according to the "to" type, except if this is
				// a widening conversion.
				IntegerType const* cleanupType = &to;
				if (fromCategory != Type::Category::RationalNumber)
				{
					IntegerType const& from =
						fromCategory == Type::Category::Integer ?
						dynamic_cast<IntegerType const&>(_from) :
						addressType;
					if (to.numBits() > from.numBits())
						cleanupType = &from;
				}
				body =
					Whiskers("converted := <cleanInt>(value)")
					("cleanInt", cleanupFunction(*cleanupType))
					.render();
			}
			break;
		}
		case Type::Category::Bool:
		{
			solAssert(_from == _to, "Invalid conversion for bool.");
			body =
				Whiskers("converted := <clean>(value)")
				("clean", cleanupFunction(_from))
				.render();
			break;
		}
		case Type::Category::FixedPoint:
			solUnimplemented("Fixed point types not implemented.");
			break;
		case Type::Category::Array:
			solUnimplementedAssert(false, "Array conversion not implemented.");
			break;
		case Type::Category::Struct:
			solUnimplementedAssert(false, "Struct conversion not implemented.");
			break;
		case Type::Category::FixedBytes:
		{
			FixedBytesType const& from = dynamic_cast<FixedBytesType const&>(_from);
			if (toCategory == Type::Category::Integer)
				// Right-align the bytes, then convert from an integer of the same width.
				body =
					Whiskers("converted := <convert>(<shift>(value))")
					("shift", shiftRightFunction(256 - from.numBytes() * 8))
					("convert", conversionFunction(IntegerType(from.numBytes() * 8), _to))
					.render();
			else if (toCategory == Type::Category::Address)
				body =
					Whiskers("converted := <convert>(value)")
						("convert", conversionFunction(_from, IntegerType(160)))
						.render();
			else
			{
				// clear for conversion to longer bytes
				solAssert(toCategory == Type::Category::FixedBytes, "Invalid type conversion requested.");
				body =
					Whiskers("converted := <clean>(value)")
					("clean", cleanupFunction(from))
					.render();
			}
			break;
		}
		case Type::Category::Function:
		{
			solAssert(false, "Conversion should not be called for function types.");
			break;
		}
		case Type::Category::Enum:
		{
			solAssert(toCategory == Type::Category::Integer || _from == _to, "");
			EnumType const& enumType = dynamic_cast<EnumType const&>(_from);
			body =
				Whiskers("converted := <clean>(value)")
				("clean", cleanupFunction(enumType))
				.render();
			break;
		}
		case Type::Category::Tuple:
		{
			solUnimplementedAssert(false, "Tuple conversion not implemented.");
			break;
		}
		default:
			solAssert(false, "");
		}
		// Every path that did not fail above must have produced code.
		solAssert(!body.empty(), _from.canonicalName() + " to " + _to.canonicalName());
		templ("body", body);
		return templ.render();
	});
}
/// Requests a Yul helper "cleanup_<type identifier>" that brings a stack value into
/// the canonical form of the given type.
/// - Address, Contract: delegate to the cleanup of a 160-bit integer / address type.
/// - Integer: unchanged for 256 bits, sign-extended if signed, masked otherwise.
/// - RationalNumber: unchanged.
/// - Bool: normalised via iszero(iszero(value)).
/// - Function: only external functions (asserted); cleaned like 24-byte fixed bytes.
/// - Array, Struct, Mapping: only storage references (asserted); unchanged.
/// - FixedBytes: masked to the high-order numBytes bytes (32 bytes: unchanged).
/// - Enum: value is passed through after calling the validator for the type.
/// - InaccessibleDynamic: always zero.
/// Fixed point types are reported as unimplemented, other categories assert.
string YulUtilFunctions::cleanupFunction(Type const& _type)
{
	string functionName = string("cleanup_") + _type.identifier();
	return m_functionCollector->createFunction(functionName, [&]() {
		Whiskers templ(R"(
			function <functionName>(value) -> cleaned {
				<body>
			}
		)");
		templ("functionName", functionName);
		switch (_type.category())
		{
		case Type::Category::Address:
			templ("body", "cleaned := " + cleanupFunction(IntegerType(160)) + "(value)");
			break;
		case Type::Category::Integer:
		{
			IntegerType const& type = dynamic_cast<IntegerType const&>(_type);
			if (type.numBits() == 256)
				templ("body", "cleaned := value");
			else if (type.isSigned())
				// signextend takes the index of the byte that holds the sign bit.
				templ("body", "cleaned := signextend(" + to_string(type.numBits() / 8 - 1) + ", value)");
			else
				templ("body", "cleaned := and(value, " + toCompactHexWithPrefix((u256(1) << type.numBits()) - 1) + ")");
			break;
		}
		case Type::Category::RationalNumber:
			templ("body", "cleaned := value");
			break;
		case Type::Category::Bool:
			templ("body", "cleaned := iszero(iszero(value))");
			break;
		case Type::Category::FixedPoint:
			solUnimplemented("Fixed point types not implemented.");
			break;
		case Type::Category::Function:
			solAssert(dynamic_cast<FunctionType const&>(_type).kind() == FunctionType::Kind::External, "");
			templ("body", "cleaned := " + cleanupFunction(FixedBytesType(24)) + "(value)");
			break;
		case Type::Category::Array:
		case Type::Category::Struct:
		case Type::Category::Mapping:
			solAssert(_type.dataStoredIn(DataLocation::Storage), "Cleanup requested for non-storage reference type.");
			templ("body", "cleaned := value");
			break;
		case Type::Category::FixedBytes:
		{
			FixedBytesType const& type = dynamic_cast<FixedBytesType const&>(_type);
			if (type.numBytes() == 32)
				templ("body", "cleaned := value");
			else if (type.numBytes() == 0)
				// This is disallowed in the type system.
				solAssert(false, "");
			else
			{
				// Mask that keeps only the numBits highest bits of the word.
				size_t numBits = type.numBytes() * 8;
				u256 mask = ((u256(1) << numBits) - 1) << (256 - numBits);
				templ("body", "cleaned := and(value, " + toCompactHexWithPrefix(mask) + ")");
			}
			break;
		}
		case Type::Category::Contract:
		{
			AddressType addressType(dynamic_cast<ContractType const&>(_type).isPayable() ?
				StateMutability::Payable :
				StateMutability::NonPayable
			);
			templ("body", "cleaned := " + cleanupFunction(addressType) + "(value)");
			break;
		}
		case Type::Category::Enum:
		{
			// Out of range enums cannot be truncated unambiguously and therefore it should be an error.
			// The body consists of two statements: the assignment and the validator call.
			templ("body", "cleaned := value " + validatorFunction(_type) + "(value)");
			break;
		}
		case Type::Category::InaccessibleDynamic:
			templ("body", "cleaned := 0");
			break;
		default:
			solAssert(false, "Cleanup of type " + _type.identifier() + " requested.");
		}
		return templ.render();
	});
}
/// Requests a Yul helper "validator_revert_<type identifier>" or
/// "validator_assert_<type identifier>" that checks whether a stack value is a valid
/// encoding of the given type and fails otherwise.
/// @param _revertOnFailure if true the generated code fails with revert(0, 0),
/// otherwise with invalid().
/// For most categories a value is valid if it is unchanged by the type's cleanup
/// function; enum values must be smaller than the number of members; inaccessible
/// dynamic types are always accepted. Other categories trigger an assertion.
string YulUtilFunctions::validatorFunction(Type const& _type, bool _revertOnFailure)
{
	string functionName = string("validator_") + (_revertOnFailure ? "revert_" : "assert_") + _type.identifier();
	return m_functionCollector->createFunction(functionName, [&]() {
		Whiskers templ(R"(
			function <functionName>(value) {
				if iszero(<condition>) { <failure> }
			}
		)");
		templ("functionName", functionName);
		if (_revertOnFailure)
			templ("failure", "revert(0, 0)");
		else
			templ("failure", "invalid()");
		switch (_type.category())
		{
		case Type::Category::Address:
		case Type::Category::Integer:
		case Type::Category::RationalNumber:
		case Type::Category::Bool:
		case Type::Category::FixedPoint:
		case Type::Category::Function:
		case Type::Category::Array:
		case Type::Category::Struct:
		case Type::Category::Mapping:
		case Type::Category::FixedBytes:
		case Type::Category::Contract:
		{
			templ("condition", "eq(value, " + cleanupFunction(_type) + "(value))");
			break;
		}
		case Type::Category::Enum:
		{
			size_t members = dynamic_cast<EnumType const&>(_type).numberOfMembers();
			solAssert(members > 0, "empty enum should have caused a parser error.");
			templ("condition", "lt(value, " + to_string(members) + ")");
			break;
		}
		case Type::Category::InaccessibleDynamic:
			templ("condition", "1");
			break;
		default:
			solAssert(false, "Validation of type " + _type.identifier() + " requested.");
		}
		return templ.render();
	});
}
/// Builds a comma-separated list of names of the form `_baseName` + number.
/// If `_startSuffix` < `_endSuffix`, the numbers run upwards from `_startSuffix`
/// to `_endSuffix` - 1. If `_startSuffix` > `_endSuffix`, they run downwards from
/// `_startSuffix` - 1 to `_endSuffix`. Equal bounds yield an empty string.
string YulUtilFunctions::suffixedVariableNameList(string const& _baseName, size_t _startSuffix, size_t _endSuffix)
{
	bool const ascending = _startSuffix < _endSuffix;
	size_t const count = ascending ? _endSuffix - _startSuffix : _startSuffix - _endSuffix;
	string list;
	for (size_t k = 0; k < count; ++k)
	{
		// Walk away from the start bound; the descending range excludes the start itself.
		size_t const suffix = ascending ? _startSuffix + k : _startSuffix - 1 - k;
		if (k != 0)
			list += ", ";
		list += _baseName + to_string(suffix);
	}
	return list;
}