diff --git a/ecies.go b/ecies.go index 0e2403d47..18952fc0b 100644 --- a/ecies.go +++ b/ecies.go @@ -13,11 +13,12 @@ import ( ) var ( - ErrImport = fmt.Errorf("ecies: failed to import key") - ErrInvalidCurve = fmt.Errorf("ecies: invalid elliptic curve") - ErrInvalidParams = fmt.Errorf("ecies: invalid ECIES parameters") - ErrInvalidPublicKey = fmt.Errorf("ecies: invalid public key") - ErrSharedKeyTooBig = fmt.Errorf("ecies: shared key is too big") + ErrImport = fmt.Errorf("ecies: failed to import key") + ErrInvalidCurve = fmt.Errorf("ecies: invalid elliptic curve") + ErrInvalidParams = fmt.Errorf("ecies: invalid ECIES parameters") + ErrInvalidPublicKey = fmt.Errorf("ecies: invalid public key") + ErrSharedKeyIsPointAtInfinity = fmt.Errorf("ecies: shared key is point at infinity") + ErrSharedKeyTooBig = fmt.Errorf("ecies: shared key params are too big") ) // PublicKey is a representation of an elliptic curve public key. @@ -90,16 +91,20 @@ func MaxSharedKeyLength(pub *PublicKey) int { // ECDH key agreement method used to establish secret keys for encryption. func (prv *PrivateKey) GenerateShared(pub *PublicKey, skLen, macLen int) (sk []byte, err error) { if prv.PublicKey.Curve != pub.Curve { - err = ErrInvalidCurve - return + return nil, ErrInvalidCurve + } + if skLen+macLen > MaxSharedKeyLength(pub) { + return nil, ErrSharedKeyTooBig } x, _ := pub.Curve.ScalarMult(pub.X, pub.Y, prv.D.Bytes()) - if x == nil || (x.BitLen()+7)/8 < (skLen+macLen) { - err = ErrSharedKeyTooBig - return + if x == nil { + return nil, ErrSharedKeyIsPointAtInfinity } - sk = x.Bytes()[:skLen+macLen] - return + + sk = make([]byte, skLen+macLen) + skBytes := x.Bytes() + copy(sk[len(sk)-len(skBytes):], skBytes) + return sk, nil } var (