package net

import (
	"context"
	"net"

	"golang.org/x/xerrors"

	logging "github.com/ipfs/go-log/v2"
	manet "github.com/multiformats/go-multiaddr/net"

	"github.com/filecoin-project/lotus/api"
)

var cLog = logging.Logger("conngater")

func (a *NetAPI) NetBlockAdd(ctx context.Context, acl api.NetBlockList) error {
	for _, p := range acl.Peers {
		err := a.ConnGater.BlockPeer(p)
		if err != nil {
			return xerrors.Errorf("error blocking peer %s: %w", p, err)
		}

		for _, c := range a.Host.Network().ConnsToPeer(p) {
			err = c.Close()
			if err != nil {
				// just log this, don't fail
				cLog.Warnf("error closing connection to %s: %s", p, err)
			}
		}
	}

	for _, addr := range acl.IPAddrs {
		ip := net.ParseIP(addr)
		if ip == nil {
			return xerrors.Errorf("error parsing IP address %s", addr)
		}

		err := a.ConnGater.BlockAddr(ip)
		if err != nil {
			return xerrors.Errorf("error blocking IP address %s: %w", addr, err)
		}

		for _, c := range a.Host.Network().Conns() {
			remote := c.RemoteMultiaddr()
			remoteIP, err := manet.ToIP(remote)
			if err != nil {
				continue
			}

			if ip.Equal(remoteIP) {
				err = c.Close()
				if err != nil {
					// just log this, don't fail
					cLog.Warnf("error closing connection to %s: %s", remoteIP, err)
				}
			}
		}
	}

	for _, subnet := range acl.IPSubnets {
		_, cidr, err := net.ParseCIDR(subnet)
		if err != nil {
			return xerrors.Errorf("error parsing subnet %s: %w", subnet, err)
		}

		err = a.ConnGater.BlockSubnet(cidr)
		if err != nil {
			return xerrors.Errorf("error blocking subunet %s: %w", subnet, err)
		}

		for _, c := range a.Host.Network().Conns() {
			remote := c.RemoteMultiaddr()
			remoteIP, err := manet.ToIP(remote)
			if err != nil {
				continue
			}

			if cidr.Contains(remoteIP) {
				err = c.Close()
				if err != nil {
					// just log this, don't fail
					cLog.Warnf("error closing connection to %s: %s", remoteIP, err)
				}
			}
		}
	}

	return nil
}

func (a *NetAPI) NetBlockRemove(ctx context.Context, acl api.NetBlockList) error {
	for _, p := range acl.Peers {
		err := a.ConnGater.UnblockPeer(p)
		if err != nil {
			return xerrors.Errorf("error unblocking peer %s: %w", p, err)
		}
	}

	for _, addr := range acl.IPAddrs {
		ip := net.ParseIP(addr)
		if ip == nil {
			return xerrors.Errorf("error parsing IP address %s", addr)
		}

		err := a.ConnGater.UnblockAddr(ip)
		if err != nil {
			return xerrors.Errorf("error unblocking IP address %s: %w", addr, err)
		}
	}

	for _, subnet := range acl.IPSubnets {
		_, cidr, err := net.ParseCIDR(subnet)
		if err != nil {
			return xerrors.Errorf("error parsing subnet %s: %w", subnet, err)
		}

		err = a.ConnGater.UnblockSubnet(cidr)
		if err != nil {
			return xerrors.Errorf("error unblocking subunet %s: %w", subnet, err)
		}
	}

	return nil
}

func (a *NetAPI) NetBlockList(ctx context.Context) (result api.NetBlockList, err error) {
	result.Peers = a.ConnGater.ListBlockedPeers()
	for _, ip := range a.ConnGater.ListBlockedAddrs() {
		result.IPAddrs = append(result.IPAddrs, ip.String())
	}
	for _, subnet := range a.ConnGater.ListBlockedSubnets() {
		result.IPSubnets = append(result.IPSubnets, subnet.String())
	}
	return
}