計算機科学のブログ

楕円曲線暗号 ビットコイン用の曲線の定義 secp256k1の計算 有限体の乗算とべき演算、点のスカラー倍算、アルゴリズムの変更、任意の大きさの整数、mathインポートパス、bigパッケージ、Int型

プログラミング・ビットコイン ―ゼロからビットコインをプログラムする方法 (Jimmy Song(著)、中川 卓俊(監修)、住田 和則(監修)、中村 昭雄(監修)、星野 靖子(翻訳)、オライリー・ジャパン)の3章(楕円曲線暗号)、3.9(ビットコイン用の曲線の定義)、3.9.1(secp256k1の計算)に出てくる大きさの整数での有限体の乗算、べき演算、点のスカラー倍算の計算速度が遅くてテストが終わらなかったから、それぞれのアルゴリズムを変更、コードを修正。

コード

ecc.go

// Package ecc implements finite field and elliptic curve arithmetic for elliptic curve cryptography (ECC).
package ecc

import (
	"fmt"
	"math/big"
)

// FieldElement is an element of a finite field: the value Num paired with the
// field's modulus Prime.
type FieldElement struct {
	Num   *big.Int
	Prime *big.Int
}

// String renders the element as "FieldElement_<Prime>(<Num>)".
func (fe FieldElement) String() string {
	const layout = "FieldElement_%v(%v)"
	return fmt.Sprintf(layout, fe.Prime, fe.Num)
}

// NewFieldElement builds the element n of the field with modulus p.
//
// It returns an error unless 0 <= n < p. The returned element stores the n
// and p pointers as given; no copies are made.
func NewFieldElement(n, p *big.Int) (FieldElement, error) {
	if n.Cmp(p) < 0 && n.Sign() >= 0 {
		return FieldElement{Num: n, Prime: p}, nil
	}
	upper := new(big.Int).Sub(p, big.NewInt(1))
	return FieldElement{}, fmt.Errorf("%v not in field range 0 to %v", n, upper)
}

// Eq reports whether every element of fes has the same Num and the same Prime
// as fe.
//
// It returns false when fe.Prime is zero. With a non-zero fe.Prime and no
// arguments it returns true.
func (fe FieldElement) Eq(fes ...FieldElement) bool {
	if fe.Prime.Sign() == 0 {
		return false
	}
	for i := range fes {
		if fe.Num.Cmp(fes[i].Num) != 0 {
			return false
		}
		if fe.Prime.Cmp(fes[i].Prime) != 0 {
			return false
		}
	}
	return true
}

// Ne reports the negation of Eq called with the same arguments.
func (fe FieldElement) Ne(fes ...FieldElement) bool {
	if fe.Eq(fes...) {
		return false
	}
	return true
}

// Add returns fe plus every element of fes, reducing modulo the prime after
// each addition.
//
// An error is returned as soon as an operand's Prime differs from fe.Prime.
// The sum is accumulated in a fresh big.Int, so fe.Num is not modified.
func (fe FieldElement) Add(fes ...FieldElement) (FieldElement, error) {
	sum := new(big.Int).Set(fe.Num)
	for i := range fes {
		if fe.Prime.Cmp(fes[i].Prime) != 0 {
			return FieldElement{}, fmt.Errorf("Cannot add two numbers in different Fields")
		}
		sum.Add(sum, fes[i].Num).Mod(sum, fes[i].Prime)
	}
	return FieldElement{Num: sum, Prime: fe.Prime}, nil
}

// Sub returns fe minus every element of fes, reducing modulo the prime after
// each subtraction (big.Int.Mod is the Euclidean modulus, so the intermediate
// value is brought back to a non-negative residue).
//
// An error is returned as soon as an operand's Prime differs from fe.Prime.
// The difference is accumulated in a fresh big.Int, so fe.Num is not modified.
func (fe FieldElement) Sub(fes ...FieldElement) (FieldElement, error) {
	diff := new(big.Int).Set(fe.Num)
	for i := range fes {
		if fe.Prime.Cmp(fes[i].Prime) != 0 {
			return FieldElement{}, fmt.Errorf("Cannot sub two numbers in different Fields")
		}
		diff.Sub(diff, fes[i].Num).Mod(diff, fes[i].Prime)
	}
	return FieldElement{Num: diff, Prime: fe.Prime}, nil
}

// Mul returns the product of fe and every element of fes, reducing modulo the
// prime after each multiplication.
//
// An error is returned as soon as an operand's Prime differs from fe.Prime;
// this check is applied to every operand, including when fe or an operand is
// zero. A zero operand makes the product zero. The product is accumulated in a
// fresh big.Int, so fe.Num is not modified, and the result is validated by
// NewFieldElement before it is returned.
func (fe FieldElement) Mul(fes ...FieldElement) (FieldElement, error) {
	prod := new(big.Int).Set(fe.Num)
	for _, fey := range fes {
		if fe.Prime.Cmp(fey.Prime) != 0 {
			return FieldElement{}, fmt.Errorf("Cannot mul two numbers in different Fields")
		}
		// No shortcut for zero operands: multiplying through keeps the
		// result correct (x*0 == 0) and keeps the field check above in force.
		prod.Mul(prod, fey.Num)
		prod.Mod(prod, fe.Prime)
	}
	return NewFieldElement(prod, fe.Prime)
}

// Div returns fe divided by fey, computed as fe * fey^(Prime-2).
//
// It returns an error when the two elements have different Prime values or
// when fey.Num is zero.
func (fe FieldElement) Div(fey FieldElement) (FieldElement, error) {
	switch {
	case fe.Prime.Cmp(fey.Prime) != 0:
		return FieldElement{},
			fmt.Errorf("Cannot div two numbers in different Fields")
	case fey.Num.Sign() == 0:
		return FieldElement{},
			fmt.Errorf("division by zero")
	}
	// By Fermat's little theorem fey^(p-2) is the multiplicative inverse of
	// fey modulo a prime p.
	exp := new(big.Int).Sub(fey.Prime, big.NewInt(2))
	inverse := fey.Pow(exp)
	return fe.Mul(inverse)
}

// Pow returns fe raised to the power y, reduced modulo fe.Prime.
//
// A negative exponent is first mapped to the equivalent non-negative one
// modulo Prime-1 (Fermat's little theorem: x^(p-1) == 1 for x != 0 and prime
// p), so Pow(-1) yields the multiplicative inverse. An exponent of zero yields
// 1. The caller's y is never modified.
func (fe FieldElement) Pow(y *big.Int) FieldElement {
	exp := y
	if y.Sign() < 0 {
		// Euclidean Mod maps any negative y into [0, Prime-2], however large
		// its magnitude; the result goes into a new big.Int so y is untouched.
		order := new(big.Int).Sub(fe.Prime, big.NewInt(1))
		exp = new(big.Int).Mod(y, order)
	}
	// big.Int.Exp performs modular exponentiation with intermediate values
	// kept reduced modulo Prime.
	n := new(big.Int).Exp(fe.Num, exp, fe.Prime)
	return FieldElement{Num: n, Prime: fe.Prime}
}

// Point is a point on the elliptic curve y^2 = x^3 + Ax + B over a finite
// field.
//
// A and B are the curve coefficients and X and Y the coordinates. Infinity
// marks the point at infinity; NewPoint leaves X and Y at their zero value in
// that case.
type Point struct {
	A        FieldElement
	B        FieldElement
	Infinity bool
	X        FieldElement
	Y        FieldElement
}

// NewPoint builds a point on the curve y^2 = x^3 + ax + b.
//
// All four field elements must share x's Prime, otherwise an error is
// returned. When inf is true the point at infinity for that curve is returned
// and x and y are not stored. Otherwise (x, y) must satisfy the curve
// equation, or an error is returned.
func NewPoint(inf bool, x, y, a, b FieldElement) (Point, error) {
	for _, other := range [...]FieldElement{y, a, b} {
		if x.Prime.Cmp(other.Prime) != 0 {
			return Point{}, fmt.Errorf("(%v, %v) in different Fields", x, other)
		}
	}
	if inf {
		return Point{A: a, B: b, Infinity: true}, nil
	}
	// Errors from the field operations are ignored here; the prime-mismatch
	// error cannot occur after the check above.
	lhs, _ := y.Mul(y)
	cube, _ := x.Mul(x, x)
	ax, _ := a.Mul(x)
	rhs, _ := cube.Add(ax, b)
	if lhs.Eq(rhs) {
		return Point{A: a, B: b, X: x, Y: y}, nil
	}
	return Point{}, fmt.Errorf("(%v, %v) is not on the curve", x, y)
}

// String renders the point as "Point(X,Y)_A_B FieldElement(Prime)", with
// "Infinity" in place of the coordinates for the point at infinity.
func (p Point) String() string {
	if !p.Infinity {
		return fmt.Sprintf("Point(%v,%v)_%v_%v FieldElement(%v)",
			p.X.Num, p.Y.Num, p.A.Num, p.B.Num, p.A.Prime)
	}
	return fmt.Sprintf("Point(Infinity)_%v_%v FieldElement(%v)",
		p.A.Num, p.B.Num, p.A.Prime)
}

// Eq reports whether p and py are the same point of the same curve.
//
// The curve coefficients A and B must match. Two points at infinity are
// equal; a point at infinity never equals a finite point; finite points are
// compared by X and Y.
func (p Point) Eq(py Point) bool {
	if !p.A.Eq(py.A) || !p.B.Eq(py.B) {
		return false
	}
	if p.Infinity || py.Infinity {
		return p.Infinity && py.Infinity
	}
	return p.X.Eq(py.X) && p.Y.Eq(py.Y)
}

// Ne reports the negation of Eq for the same argument.
func (p Point) Ne(py Point) bool {
	if p.Eq(py) {
		return false
	}
	return true
}

// Add returns the sum p + py under the elliptic-curve group law.
//
// Both points must have equal curve coefficients A and B, otherwise an error
// is returned. The cases are handled in this order:
//
//   - either operand is the point at infinity: the other operand is returned;
//   - equal X but different Y (vertical chord): the point at infinity;
//   - p equals py and Y is zero (vertical tangent): the point at infinity;
//   - p equals py: doubling, with slope s = (3*X^2 + A) / (2*Y);
//   - otherwise: chord, with slope s = (Y2 - Y1) / (X2 - X1).
//
// For the last two cases the result is X3 = s^2 - X1 - X2 and
// Y3 = s*(X1 - X3) - Y1. The result is built through NewPoint, so it is
// checked against the curve equation before being returned. Every error from
// the underlying field operations is propagated to the caller.
func (p Point) Add(py Point) (Point, error) {
	if p.A.Ne(py.A) || p.B.Ne(py.B) {
		return Point{},
			fmt.Errorf("Points %v, %v are not on the same curve", p, py)
	}
	if p.Infinity {
		return py, nil
	}
	if py.Infinity {
		return p, nil
	}
	zero, err := NewFieldElement(big.NewInt(0), p.A.Prime)
	if err != nil {
		return Point{}, err
	}
	if p.X.Eq(py.X) && p.Y.Ne(py.Y) {
		return NewPoint(true, zero, zero, p.A, p.B)
	}
	doubling := p.Eq(py)
	if doubling && p.Y.Eq(zero) {
		return NewPoint(true, zero, zero, p.A, p.B)
	}

	// Numerator and denominator of the slope of the tangent (doubling) or of
	// the chord through the two points.
	var num, den FieldElement
	if doubling {
		var three, two FieldElement
		if three, err = NewFieldElement(big.NewInt(3), p.A.Prime); err != nil {
			return Point{}, err
		}
		if num, err = three.Mul(p.X, p.X); err != nil {
			return Point{}, err
		}
		if num, err = num.Add(p.A); err != nil {
			return Point{}, err
		}
		if two, err = NewFieldElement(big.NewInt(2), p.A.Prime); err != nil {
			return Point{}, err
		}
		if den, err = two.Mul(p.Y); err != nil {
			return Point{}, err
		}
	} else {
		if num, err = py.Y.Sub(p.Y); err != nil {
			return Point{}, err
		}
		if den, err = py.X.Sub(p.X); err != nil {
			return Point{}, err
		}
	}
	s, err := num.Div(den)
	if err != nil {
		return Point{}, err
	}

	// X3 = s^2 - X1 - X2. When doubling, X2 equals X1, so this is the same
	// value as s^2 - 2*X1.
	x, err := s.Mul(s)
	if err != nil {
		return Point{}, err
	}
	if x, err = x.Sub(p.X, py.X); err != nil {
		return Point{}, err
	}

	// Y3 = s*(X1 - X3) - Y1.
	y, err := p.X.Sub(x)
	if err != nil {
		return Point{}, err
	}
	if y, err = s.Mul(y); err != nil {
		return Point{}, err
	}
	if y, err = y.Sub(p.Y); err != nil {
		return Point{}, err
	}
	return NewPoint(false, x, y, p.A, p.B)
}

// ScalarMul returns s times p, computed by binary double-and-add over the
// bits of s from least to most significant.
//
// Only the absolute value of s is used (its sign is ignored), and s itself is
// not modified. A zero scalar yields the point at infinity of p's curve.
func (p Point) ScalarMul(s *big.Int) Point {
	// ScalarMul has no error return, so errors from the helpers below cannot
	// be reported and are discarded.
	z, _ := NewFieldElement(big.NewInt(0), p.A.Prime)
	acc, _ := NewPoint(true, z, z, p.A, p.B)
	k := new(big.Int).Abs(s)
	cur := p
	// Walk every bit position up to the highest set bit. cur must be doubled
	// once per position, including positions holding a zero bit, so that it
	// always equals 2^i * p when bit i is examined.
	for i, n := 0, k.BitLen(); i < n; i++ {
		if k.Bit(i) == 1 {
			acc, _ = acc.Add(cur)
		}
		if i+1 < n {
			// No doubling after the top bit: the value would never be used.
			cur, _ = cur.Add(cur)
		}
	}
	return acc
}

ecc_test.go

package ecc

import (
	"math/big"
	"testing"
)

// TestFieldElementEq checks that Eq is false for different values of the same
// field and true for an element compared with itself.
func TestFieldElementEq(t *testing.T) {
	p := big.NewInt(13)
	a, _ := NewFieldElement(big.NewInt(7), p)
	b, _ := NewFieldElement(big.NewInt(6), p)
	if got := a.Eq(b); got {
		t.Errorf("%v.Eq(%v) got %v, want false", a, b, got)
	}
	// Report the value actually returned, not the expected one.
	if got := a.Eq(a); !got {
		t.Errorf("%v.Eq(%v) got %v, want true", a, a, got)
	}
}

// TestFieldElementNe checks that Ne is true for different values of the same
// field and false for an element compared with itself.
func TestFieldElementNe(t *testing.T) {
	prime := big.NewInt(13)
	seven, _ := NewFieldElement(big.NewInt(7), prime)
	six, _ := NewFieldElement(big.NewInt(6), prime)
	if got := seven.Ne(six); !got {
		t.Errorf("%v.Ne(%v) got %v, want true", seven, six, got)
	}
	if got := seven.Ne(seven); got {
		t.Errorf("%v.Ne(%v) got %v, want false", seven, seven, got)
	}
}

func TestFieldElementAdd(t *testing.T) {
	p := big.NewInt(13)
	a, _ := NewFieldElement(big.NewInt(7), p)
	b, _ := NewFieldElement(big.NewInt(12), p)
	c, _ := NewFieldElement(big.NewInt(6), p)
	d, _ := a.Add(b)
	if d.Ne(c) {
		t.Errorf("%v.Add(%v) got %v, want %v", a, b, d, c)
	}
}

func TestFieldElementSub(t *testing.T) {
	p := big.NewInt(19)
	tests := []struct {
		x, y, want int64
	}{
		{11, 9, 2},
		{6, 13, 12},
	}
	for _, test := range tests {
		x, _ := NewFieldElement(big.NewInt(test.x), p)
		y, _ := NewFieldElement(big.NewInt(test.y), p)
		got, _ := x.Sub(y)
		want, _ := NewFieldElement(big.NewInt(test.want), p)
		if got.Ne(want) {
			t.Errorf("%v.Sub(%v) got %v, want %v", x, y, got, want)
		}
	}
}

func TestFieldElementMul(t *testing.T) {
	p := big.NewInt(13)
	x, _ := NewFieldElement(big.NewInt(3), p)
	y, _ := NewFieldElement(big.NewInt(12), p)
	got, _ := x.Mul(y)
	want, _ := NewFieldElement(big.NewInt(10), p)
	if got.Ne(want) {
		t.Errorf("%v.Mul(%v) got %v, want %v", x, y, got, want)
	}
}

func TestFieldElementDiv(t *testing.T) {
	p := big.NewInt(19)
	tests := []struct {
		num, den, want int64
	}{
		{2, 7, 3},
		{7, 5, 9},
	}
	for _, test := range tests {
		num, _ := NewFieldElement(big.NewInt(test.num), p)
		den, _ := NewFieldElement(big.NewInt(test.den), p)
		got, _ := num.Div(den)
		want, _ := NewFieldElement(big.NewInt(test.want), p)
		if got.Ne(want) {
			t.Errorf("%v.Div(%v) got %v, want %v", num, den, got, want)
		}
	}
}

func TestFieldElementPow(t *testing.T) {
	p := big.NewInt(13)
	tests := []struct {
		x, y, want int64
	}{
		{3, 3, 1},
		{7, -3, 8},
	}
	for _, test := range tests {
		x, _ := NewFieldElement(big.NewInt(test.x), p)
		got := x.Pow(big.NewInt(test.y))
		want, _ := NewFieldElement(big.NewInt(test.want), p)
		if got.Ne(want) {
			t.Errorf("%v.Pow(%v) got %v, want %v", x, test.y, got, want)
		}
	}
}

func TestPointString(t *testing.T) {
	p := big.NewInt(223)
	a, _ := NewFieldElement(big.NewInt(0), p)
	b, _ := NewFieldElement(big.NewInt(7), p)
	x, _ := NewFieldElement(big.NewInt(192), p)
	y, _ := NewFieldElement(big.NewInt(105), p)
	point, _ := NewPoint(false, x, y, a, b)
	want := "Point(192,105)_0_7 FieldElement(223)"
	got := point.String()
	if got != want {
		t.Errorf("%v.String() got %v, want %v", point, got, want)
	}
}

func TestPointOnCurve(t *testing.T) {
	p := big.NewInt(223)
	a, _ := NewFieldElement(big.NewInt(0), p)
	b, _ := NewFieldElement(big.NewInt(7), p)
	validPoints := []struct {
		x, y int64
	}{
		{192, 105},
		{17, 56},
		{1, 193},
	}
	for _, xy := range validPoints {
		x, _ := NewFieldElement(big.NewInt(xy.x), p)
		y, _ := NewFieldElement(big.NewInt(xy.y), p)
		point, err := NewPoint(false, x, y, a, b)
		if err != nil {
			t.Errorf("%v is not on curve", point)
		}
	}
	invalidPoints := []struct {
		x, y int64
	}{
		{200, 119},
		{42, 99},
	}
	for _, xy := range invalidPoints {
		x, _ := NewFieldElement(big.NewInt(xy.x), p)
		y, _ := NewFieldElement(big.NewInt(xy.y), p)
		point, err := NewPoint(false, x, y, a, b)
		if err == nil {
			t.Errorf("%v is on curve", point)
		}
	}
}

func TestPointAdd(t *testing.T) {
	p := big.NewInt(223)
	a, _ := NewFieldElement(big.NewInt(0), p)
	b, _ := NewFieldElement(big.NewInt(7), p)
	tests := []struct {
		x1, y1, x2, y2, x3, y3 int64
	}{
		{170, 142, 60, 139, 220, 181},
		{47, 71, 17, 56, 215, 68},
		{143, 98, 76, 66, 47, 71},
	}
	for _, test := range tests {
		x1, _ := NewFieldElement(big.NewInt(test.x1), p)
		y1, _ := NewFieldElement(big.NewInt(test.y1), p)
		p1, _ := NewPoint(false, x1, y1, a, b)
		x2, _ := NewFieldElement(big.NewInt(test.x2), p)
		y2, _ := NewFieldElement(big.NewInt(test.y2), p)
		p2, _ := NewPoint(false, x2, y2, a, b)
		got, _ := p1.Add(p2)
		x3, _ := NewFieldElement(big.NewInt(test.x3), p)
		y3, _ := NewFieldElement(big.NewInt(test.y3), p)
		want, _ := NewPoint(false, x3, y3, a, b)
		if got.Ne(want) {
			t.Errorf("%v.Add(%v) got %v, want %v", p1, p2, got, want)
		}
	}
}

func TestPointScalarMul(t *testing.T) {
	p := big.NewInt(223)
	a, _ := NewFieldElement(big.NewInt(0), p)
	b, _ := NewFieldElement(big.NewInt(7), p)
	x, _ := NewFieldElement(big.NewInt(47), p)
	y, _ := NewFieldElement(big.NewInt(71), p)
	point, _ := NewPoint(false, x, y, a, b)
	tests := []struct {
		s, x, y int64
	}{
		{2, 36, 111},
		{4, 194, 51},
		{8, 116, 55},
	}
	for _, test := range tests {
		got := point.ScalarMul(big.NewInt(test.s))
		sx, _ := NewFieldElement(big.NewInt(test.x), p)
		sy, _ := NewFieldElement(big.NewInt(test.y), p)
		want, _ := NewPoint(false, sx, sy, a, b)
		if got.Ne(want) {
			t.Errorf("%v.ScalarMul(%v) got %v, want %v", point, test.s, got, want)
		}
	}
	s := big.NewInt(21)
	got := point.ScalarMul(s)
	z, _ := NewFieldElement(big.NewInt(0), p)
	want, _ := NewPoint(true, z, z, a, b)
	if got.Ne(want) {
		t.Errorf("%v.ScalarMul(%v) got %v, want %v", point, s, got, want)
	}
}

func TestBitcoinPointOnCurve(t *testing.T) {
	p := new(big.Int)
	p.SetString("115792089237316195423570985008687907853269984665640564039457584007908834671663", 10)
	a, err := NewFieldElement(big.NewInt(0), p)
	b, _ := NewFieldElement(big.NewInt(7), p)
	x := new(big.Int)
	x.SetString("79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798", 16)
	y := new(big.Int)
	y.SetString("483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8", 16)
	gx, _ := NewFieldElement(x, p)
	gy, _ := NewFieldElement(y, p)
	g, err := NewPoint(false, gx, gy, a, b)
	if err != nil {
		t.Errorf("%v is not on curve", g)
	}
	n := new(big.Int)
	n.SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16)
	g = g.ScalarMul(n)
	if !g.Infinity {
		t.Errorf("%v is not Infinity", g)
	}
}

入出力結果

% go test
PASS
ok  	bitcoin/ecc	0.069s
% go test
--- FAIL: TestFieldElementPow (0.00s)
    ecc_test.go:111: FieldElement_13(3).Pow(3) got FieldElement_13(27), want FieldElement_13(1)
    ecc_test.go:111: FieldElement_13(7).Pow(-3) got FieldElement_13(21), want FieldElement_13(8)
FAIL
exit status 1
FAIL	bitcoin/ecc	0.220s
% go test
PASS
ok  	bitcoin/ecc	0.295s
% go test
PASS
ok  	bitcoin/ecc	0.592s
%