diff --git a/packages/codegen/src/generate-code.ts b/packages/codegen/src/generate-code.ts index f36474a6..142d5198 100644 --- a/packages/codegen/src/generate-code.ts +++ b/packages/codegen/src/generate-code.ts @@ -14,6 +14,7 @@ import os from 'os'; import { flatten } from '@poanet/solidity-flattener'; import { parse, visit } from '@solidity-parser/parser'; +import { ASTNode } from '@solidity-parser/parser/dist/src/ast-types'; import { KIND_ACTIVE, KIND_LAZY } from '@cerc-io/util'; import { MODE_ETH_CALL, MODE_STORAGE, MODE_ALL, MODE_NONE, DEFAULT_PORT } from './utils/constants'; @@ -38,7 +39,7 @@ import { getSubgraphConfig } from './utils/subgraph'; import { exportIndexBlock } from './index-block'; import { exportSubscriber } from './subscriber'; import { exportReset } from './reset'; -import { writeFileToStream } from './utils/helpers'; +import { filterInheritedContractNodes, writeFileToStream } from './utils/helpers'; const ASSET_DIR = path.resolve(__dirname, 'assets'); @@ -146,8 +147,14 @@ function parseAndVisit (visitor: Visitor, contracts: any[], mode: string) { // Get the abstract syntax tree for the flattened contract. const ast = parse(contract.contractString); - // Filter out library nodes. - ast.children = ast.children.filter(child => !(child.type === 'ContractDefinition' && child.kind === 'library')); + const contractNode = ast.children.find((node: ASTNode) => + node.type === 'ContractDefinition' && + node.name === contract.contractName + ); + + assert(contractNode); + const nodes = filterInheritedContractNodes(ast, [contractNode]); + ast.children = Array.from(nodes).concat(contractNode); visit(ast, { StateVariableDeclaration: stateVariableDeclarationVisitor, diff --git a/packages/codegen/src/utils/helpers.ts b/packages/codegen/src/utils/helpers.ts index 6bae01ee..78513c5a 100644 --- a/packages/codegen/src/utils/helpers.ts +++ b/packages/codegen/src/utils/helpers.ts @@ -4,7 +4,8 @@ import fs from 'fs'; import { Writable } from 'stream'; -import { TypeName } from '@solidity-parser/parser/dist/src/ast-types'; + +import { TypeName, ASTNode, InheritanceSpecifier, SourceUnit } from '@solidity-parser/parser/dist/src/ast-types'; export const isArrayType = (typeName: TypeName): boolean => (typeName.type === 'ArrayTypeName'); @@ -22,3 +23,38 @@ export function writeFileToStream (pathToFile: string, outStream: Writable): voi const fileStream = fs.createReadStream(pathToFile); fileStream.pipe(outStream); } + +/** + * Get inherited contracts for array of contractNodes + * @param ast + * @param contractNodes + */ +export function filterInheritedContractNodes (ast: SourceUnit, contractNodes: ASTNode[]): Set { + const resultSet: Set = new Set(); + + contractNodes.forEach((node: ASTNode) => { + if (node.type !== 'ContractDefinition') { + return; + } + + // Filter out library nodes + if (node.kind === 'library') { + return; + } + + const inheritedContracts = ast.children.filter((childNode: ASTNode) => + node.baseContracts.some((baseContract: InheritanceSpecifier) => + childNode.type === 'ContractDefinition' && baseContract.baseName.namePath === childNode.name + ) + ); + + // Add inherited contracts to result set + inheritedContracts.forEach((node: ASTNode) => resultSet.add(node)); + // Get parent inherited contracts + const parentInheritedNodes = filterInheritedContractNodes(ast, inheritedContracts); + // Add parent inherited contract nodes in result set + parentInheritedNodes.forEach((node: ASTNode) => resultSet.add(node)); + }); + + return resultSet; +}