Check if memory writes are overlapping.

This commit is contained in:
chriseth 2020-12-17 18:50:48 +01:00
parent 9328503265
commit 9dda37974d
6 changed files with 190 additions and 11 deletions

View File

@ -28,6 +28,8 @@
#include <libyul/Dialect.h>
#include <libyul/Exceptions.h>
#include <libyul/backends/evm/EVMDialect.h>
#include <libsolutil/CommonData.h>
#include <boost/range/adaptor/reversed.hpp>
@ -43,9 +45,9 @@ DataFlowAnalyzer::DataFlowAnalyzer(
Dialect const& _dialect,
map<YulString, SideEffects> _functionSideEffects
):
m_dialect(_dialect),
m_functionSideEffects(std::move(_functionSideEffects)),
m_knowledgeBase(_dialect, m_value)
m_dialect(_dialect),
m_functionSideEffects(std::move(_functionSideEffects)),
m_knowledgeBase(_dialect, m_value)
{
if (auto const* builtin = _dialect.memoryStoreFunction(YulString{}))
m_storeFunctionName[static_cast<unsigned>(StoreLoadLocation::Memory)] = builtin->name;
@ -86,7 +88,8 @@ void DataFlowAnalyzer::operator()(ExpressionStatement& _statement)
}
else
{
clearKnowledgeIfInvalidated(_statement.expression);
// TODO is it correct to visit after?
clearKnowledgeIfInvalidated(_statement.expression, true);
ASTModifier::operator()(_statement);
}
}
@ -97,7 +100,9 @@ void DataFlowAnalyzer::operator()(Assignment& _assignment)
for (auto const& var: _assignment.variableNames)
names.emplace(var.name);
assertThrow(_assignment.value, OptimizerException, "");
clearKnowledgeIfInvalidated(*_assignment.value);
clearKnowledgeIfInvalidated(*_assignment.value, true);
visit(*_assignment.value);
handleAssignment(names, _assignment.value.get(), false);
}
@ -111,7 +116,7 @@ void DataFlowAnalyzer::operator()(VariableDeclaration& _varDecl)
if (_varDecl.value)
{
clearKnowledgeIfInvalidated(*_varDecl.value);
clearKnowledgeIfInvalidated(*_varDecl.value, true);
visit(*_varDecl.value);
}
@ -120,7 +125,7 @@ void DataFlowAnalyzer::operator()(VariableDeclaration& _varDecl)
void DataFlowAnalyzer::operator()(If& _if)
{
clearKnowledgeIfInvalidated(*_if.condition);
clearKnowledgeIfInvalidated(*_if.condition, true);
InvertibleMap<YulString, YulString> storage = m_storage;
InvertibleMap<YulString, YulString> memory = m_memory;
@ -135,7 +140,7 @@ void DataFlowAnalyzer::operator()(If& _if)
void DataFlowAnalyzer::operator()(Switch& _switch)
{
clearKnowledgeIfInvalidated(*_switch.expression);
clearKnowledgeIfInvalidated(*_switch.expression, true);
visit(*_switch.expression);
set<YulString> assignedVariables;
for (auto& _case: _switch.cases)
@ -358,9 +363,25 @@ void DataFlowAnalyzer::clearKnowledgeIfInvalidated(Block const& _block)
m_memory.clear();
}
void DataFlowAnalyzer::clearKnowledgeIfInvalidated(Expression const& _expr)
void DataFlowAnalyzer::clearKnowledgeIfInvalidated(Expression const& _expr, bool _currentlyVisiting)
{
SideEffectsCollector sideEffects(m_dialect, _expr, &m_functionSideEffects);
if (_currentlyVisiting)
if (auto startLength = isMemoryAreaStore(_expr))
{
set<YulString> keysToErase;
for (auto const& item: m_memory.values)
if (!m_knowledgeBase.knownToBeNonOverlapping(item.first, startLength->first, startLength->second))
keysToErase.insert(item.first);
for (YulString const& key: keysToErase)
m_memory.eraseKey(key);
if (sideEffects.invalidatesStorage())
m_storage.clear();
return;
}
if (sideEffects.invalidatesStorage())
m_storage.clear();
if (sideEffects.invalidatesMemory())
@ -408,7 +429,7 @@ bool DataFlowAnalyzer::inScope(YulString _variableName) const
return false;
}
std::optional<pair<YulString, YulString>> DataFlowAnalyzer::isSimpleStore(
optional<pair<YulString, YulString>> DataFlowAnalyzer::isSimpleStore(
StoreLoadLocation _location,
ExpressionStatement const& _statement
) const
@ -421,6 +442,44 @@ std::optional<pair<YulString, YulString>> DataFlowAnalyzer::isSimpleStore(
return {};
}
optional<pair<YulString, YulString>> DataFlowAnalyzer::isMemoryAreaStore(
Expression const& _expression
) const
{
FunctionCall const* funCall = get_if<FunctionCall>(&_expression);
EVMDialect const* evmDialect = dynamic_cast<EVMDialect const*>(&m_dialect);
if (!evmDialect || !funCall)
return {};
// TODO ensure that all arguments are side-effect free.
array<size_t, 2> startLength = {};
YulString name = funCall->functionName.name;
if (
name == "calldatacopy"_yulstring ||
name == "codecopy"_yulstring ||
name == "reurndatacopy"_yulstring
)
startLength = {0, 2};
else if (name == "extcodecopy"_yulstring)
startLength = {1, 3};
else if (name == "call"_yulstring || name == "callcode"_yulstring)
startLength = {5, 6};
else if (name == "delegatecall"_yulstring || name == "staticcall"_yulstring)
startLength = {4, 5};
else
return {};
cout << "fun: " << name.str() << endl;
Identifier const* start = std::get_if<Identifier>(&funCall->arguments.at(startLength[0]));
Identifier const* length = std::get_if<Identifier>(&funCall->arguments.at(startLength[1]));
if (start && length)
return {{start->name, length->name}};
else
return {};
}
std::optional<YulString> DataFlowAnalyzer::isSimpleLoad(
StoreLoadLocation _location,
Expression const& _expression

View File

@ -118,7 +118,7 @@ protected:
void clearKnowledgeIfInvalidated(Block const& _block);
/// Clears knowledge about storage or memory if they may be modified inside the expression.
void clearKnowledgeIfInvalidated(Expression const& _expression);
void clearKnowledgeIfInvalidated(Expression const& _expression, bool _currentlyVisiting = false);
/// Joins knowledge about storage and memory with an older point in the control-flow.
/// This only works if the current state is a direct successor of the older point,
@ -149,6 +149,14 @@ protected:
ExpressionStatement const& _statement
) const;
/// Checks if the expression writes to an area in memory like
/// `call` or `calldatacopy`.
/// If yes, returns (start, length).
/// If the arguments have any side-effects, returns nullopt.
std::optional<std::pair<YulString, YulString>> isMemoryAreaStore(
Expression const& _expression
) const;
/// Checks if the expression is sload(a) / mload(a)
/// where a is a variable and returns the variable in that case.
std::optional<YulString> isSimpleLoad(

View File

@ -35,6 +35,19 @@ using namespace std;
using namespace solidity;
using namespace solidity::yul;
bool KnowledgeBase::knownToBeZero(YulString _a)
{
if (!m_variableValues.count(_a) || !m_variableValues.at(_a).value)
return false;
Expression const& expr = *m_variableValues.at(_a).value;
if (!holds_alternative<Literal>(expr))
return false;
u256 val = valueOfLiteral(std::get<Literal>(expr));
return val == 0;
}
bool KnowledgeBase::knownToBeDifferent(YulString _a, YulString _b)
{
// Try to use the simplification rules together with the
@ -67,6 +80,21 @@ bool KnowledgeBase::knownToBeDifferentByAtLeast32(YulString _a, YulString _b)
return false;
}
bool KnowledgeBase::knownToBeNonOverlapping(YulString _address, YulString _start, YulString _length)
{
cout << "Overlap check: " << _address.str() << " in " << _start.str() << " - " << _length.str() << endl;
if (knownToBeZero(_length))
return true;
(void)_address;
(void)_start;
// TODO extend this by trying to simplify:
// _address + 31 < _start (what about overflow?)
// _start - _address is a number larger than 31 (what about overflow?)
return false;
}
Expression KnowledgeBase::simplify(Expression _expression)
{
bool startedRecursion = (m_recursionCounter == 0);

View File

@ -45,8 +45,11 @@ public:
m_variableValues(_variableValues)
{}
bool knownToBeZero(YulString _a);
bool knownToBeDifferent(YulString _a, YulString _b);
bool knownToBeDifferentByAtLeast32(YulString _a, YulString _b);
/// @returns true if _length is zero or _address + 32 <= _start or _address > _start + _length
bool knownToBeNonOverlapping(YulString _address, YulString _start, YulString _length);
bool knownToBeEqual(YulString _a, YulString _b) const { return _a == _b; }
private:

View File

@ -0,0 +1,58 @@
{
mstore(64, 128)
if callvalue() { revert(0, 0) }
let _1 := 0
// The following two statements destroy the
// knowledge about mload(64)
//calldatacopy(128, _1, calldatasize())
//mstore(add(128, calldatasize()), _1)
pop(delegatecall(gas(), loadimmutable("2"), 128, calldatasize(), _1, _1))
let data := _1
switch returndatasize()
case 0 { data := 96 }
default {
let result := and(add(returndatasize(), 63), not(31))
let memPtr := mload(64)
let newFreePtr := add(memPtr, result)
if or(gt(newFreePtr, 0xffffffffffffffff), lt(newFreePtr, memPtr))
{
mstore(_1, shl(224, 0x4e487b71))
mstore(4, 0x41)
revert(_1, 0x24)
}
mstore(64, newFreePtr)
data := memPtr
mstore(memPtr, returndatasize())
returndatacopy(add(memPtr, 0x20), _1, returndatasize())
}
return(add(data, 0x20), mload(data))
}
// ----
// step: fullSuite
//
// {
// {
// let _1 := 128
// mstore(64, _1)
// if callvalue() { revert(0, 0) }
// let _2 := 0
// pop(delegatecall(gas(), loadimmutable("2"), _1, calldatasize(), _2, _2))
// let data := _2
// switch returndatasize()
// case 0 { data := 96 }
// default {
// let newFreePtr := add(_1, and(add(returndatasize(), 63), not(31)))
// if or(gt(newFreePtr, 0xffffffffffffffff), lt(newFreePtr, _1))
// {
// mstore(_2, shl(224, 0x4e487b71))
// mstore(4, 0x41)
// revert(_2, 0x24)
// }
// mstore(64, newFreePtr)
// data := _1
// mstore(_1, returndatasize())
// returndatacopy(160, _2, returndatasize())
// }
// return(add(data, 0x20), mload(data))
// }
// }

View File

@ -0,0 +1,23 @@
{
mstore(0x40, 7)
let d := staticcall(10000, 10, 0, 200, 0, 0)
sstore(0, mload(0x40))
mstore(0x80, 7)
calldatacopy(0, 0, 0)
sstore(0, mload(0x80))
}
// ====
// EVMVersion: >=byzantium
// ----
// step: loadResolver
//
// {
// let _1 := 7
// mstore(0x40, _1)
// let _3 := 0
// pop(staticcall(10000, 10, _3, 200, _3, _3))
// sstore(_3, _1)
// mstore(0x80, _1)
// calldatacopy(_3, _3, _3)
// sstore(_3, _1)
// }