@@ -2,6 +2,8 @@ package group
22
33import (
44 "crypto"
5+ "crypto/ecdh"
6+ "crypto/ecdsa"
57 "crypto/elliptic"
68 _ "crypto/sha256"
79 _ "crypto/sha512"
@@ -226,9 +228,20 @@ func (e *wElt) MarshalBinary() ([]byte, error) {
226228 if e .IsIdentity () {
227229 return []byte {0x0 }, nil
228230 }
231+
232+ ec := e .wG .c
233+ if e .wG == P384 {
234+ ec = elliptic .P384 ()
235+ }
236+
229237 e .x .Mod (e .x , e .c .Params ().P )
230238 e .y .Mod (e .y , e .c .Params ().P )
231- return elliptic .Marshal (e .wG .c , e .x , e .y ), nil
239+ pk , err := (& ecdsa.PublicKey {Curve : ec , X : e .x , Y : e .y }).ECDH ()
240+ if err != nil {
241+ return nil , err
242+ }
243+
244+ return pk .Bytes (), nil
232245}
233246
234247func (e * wElt ) MarshalBinaryCompress () ([]byte , error ) {
@@ -241,7 +254,8 @@ func (e *wElt) MarshalBinaryCompress() ([]byte, error) {
241254}
242255
243256func (e * wElt ) UnmarshalBinary (b []byte ) error {
244- byteLen := (e .c .Params ().BitSize + 7 ) / 8
257+ bitSize := e .c .Params ().BitSize
258+ byteLen := (bitSize + 7 ) / 8
245259 l := len (b )
246260 switch {
247261 case l == 1 && b [0 ] == 0x00 : // point at infinity
@@ -254,11 +268,22 @@ func (e *wElt) UnmarshalBinary(b []byte) error {
254268 }
255269 e .x , e .y = x , y
256270 case l == 1 + 2 * byteLen && b [0 ] == 0x04 : // uncompressed
257- x , y := elliptic .Unmarshal (e .wG .c , b )
258- if x == nil {
271+ var err error
272+ switch bitSize {
273+ case 256 :
274+ _ , err = ecdh .P256 ().NewPublicKey (b )
275+ case 384 :
276+ _ , err = ecdh .P384 ().NewPublicKey (b )
277+ case 521 :
278+ _ , err = ecdh .P521 ().NewPublicKey (b )
279+ }
280+
281+ if err != nil {
259282 return ErrUnmarshal
260283 }
261- e .x , e .y = x , y
284+
285+ e .x .SetBytes (b [1 : 1 + byteLen ])
286+ e .y .SetBytes (b [1 + byteLen : 1 + 2 * byteLen ])
262287 default :
263288 return ErrUnmarshal
264289 }
0 commit comments