計算機科学のブログ

楕円曲線暗号 ビットコイン用の曲線の定義 secp256k1の計算 有限体における点の加算、スカラー倍算のコーディング、任意の大きさの整数、math/bigインポートパス、bigパッケージ、Int型

プログラミング・ビットコイン ―ゼロからビットコインをプログラムする方法 (Jimmy Song(著)、中川 卓俊(監修)、住田 和則(監修)、中村 昭雄(監修)、星野 靖子(翻訳)、オライリー・ジャパン)の3章(楕円曲線暗号)、3.9(ビットコイン用の曲線の定義)、3.9.1(secp256k1の計算)に出てくる大きな整数を扱えるように、有限体における点の加算、スカラー倍算のコードを修正。

コード

ecc.go

// Package ecc implements prime-field and elliptic-curve arithmetic for
// Elliptic Curve Cryptography (楕円曲線暗号).
package ecc

import (
	"fmt"
	"math/big"
)

// FieldElement is the element Num of the field of integers modulo Prime.
// Both values are arbitrary-precision integers, so fields as large as the
// one used by secp256k1 can be represented.
type FieldElement struct {
	Num   *big.Int
	Prime *big.Int
}

// String formats the element as "FieldElement_<Prime>(<Num>)".
func (fe FieldElement) String() string {
	prime, num := fe.Prime, fe.Num
	return fmt.Sprintf("FieldElement_%v(%v)", prime, num)
}

// NewFieldElement returns the element n of the field of integers modulo p.
//
// n must satisfy 0 <= n < p; otherwise an error stating the valid range
// 0..p-1 is returned. n and p are stored as given (not copied), so the
// caller should not mutate them afterwards.
func NewFieldElement(n, p *big.Int) (FieldElement, error) {
	if n.Sign() >= 0 && n.Cmp(p) < 0 {
		return FieldElement{n, p}, nil
	}
	maxNum := new(big.Int).Sub(p, big.NewInt(1))
	return FieldElement{},
		fmt.Errorf("%v not in field range 0 to %v", n, maxNum)
}

// Eq reports whether fe equals every element of fes, i.e. all of them have
// the same Num and the same Prime.
//
// An element with a nil Num, a nil Prime or a zero Prime is not usable and
// never compares equal. This includes the zero value FieldElement{}, which
// is what a Point at infinity holds in X and Y. The old `Prime == 0` guard
// did not catch the nil Prime of the zero value and panicked instead.
// With no arguments, Eq reports whether fe itself is usable.
func (fe FieldElement) Eq(fes ...FieldElement) bool {
	usable := func(e FieldElement) bool {
		return e.Num != nil && e.Prime != nil && e.Prime.Sign() != 0
	}
	if !usable(fe) {
		return false
	}
	for _, y := range fes {
		if !usable(y) || fe.Num.Cmp(y.Num) != 0 || fe.Prime.Cmp(y.Prime) != 0 {
			return false
		}
	}
	return true
}

// Ne is the negation of Eq: it reports whether fe is not equal to every
// element of fes.
func (fe FieldElement) Ne(fes ...FieldElement) bool {
	equal := fe.Eq(fes...)
	return !equal
}

// Add returns fe + fes[0] + fes[1] + ... reduced modulo fe.Prime.
//
// Every operand must have the same Prime as fe; otherwise an error is
// returned. The operands are not modified.
func (fe FieldElement) Add(fes ...FieldElement) (FieldElement, error) {
	for _, other := range fes {
		if fe.Prime.Cmp(other.Prime) != 0 {
			return FieldElement{}, fmt.Errorf("Cannot add two numbers in different Fields")
		}
	}
	sum := new(big.Int).Set(fe.Num)
	for _, other := range fes {
		sum.Mod(sum.Add(sum, other.Num), fe.Prime)
	}
	return FieldElement{sum, fe.Prime}, nil
}

// Sub returns fe - fes[0] - fes[1] - ... reduced modulo fe.Prime.
//
// big.Int.Mod is Euclidean, so the result is always in [0, Prime).
// Every operand must have the same Prime as fe; otherwise an error is
// returned.
func (fe FieldElement) Sub(fes ...FieldElement) (FieldElement, error) {
	for _, other := range fes {
		if fe.Prime.Cmp(other.Prime) != 0 {
			return FieldElement{}, fmt.Errorf("Cannot sub two numbers in different Fields")
		}
	}
	diff := new(big.Int).Set(fe.Num)
	for _, other := range fes {
		diff.Mod(diff.Sub(diff, other.Num), fe.Prime)
	}
	return FieldElement{Num: diff, Prime: fe.Prime}, nil
}

// Mul returns fe * fes[0] * fes[1] * ... reduced modulo fe.Prime.
//
// Every operand must have the same Prime as fe; otherwise an error is
// returned. The check is done before any work, even when a factor is zero.
// Each product uses big.Int multiplication and is reduced modulo Prime, so
// the cost depends on the bit length of the operands. The previous
// implementation used repeated addition, whose cost grew with their
// numeric value and never finished for secp256k1-sized numbers.
func (fe FieldElement) Mul(fes ...FieldElement) (FieldElement, error) {
	for _, fey := range fes {
		if fe.Prime.Cmp(fey.Prime) != 0 {
			return FieldElement{}, fmt.Errorf("Cannot mul two numbers in different Fields")
		}
	}
	prod := new(big.Int).Set(fe.Num)
	for _, fey := range fes {
		prod.Mul(prod, fey.Num)
		prod.Mod(prod, fe.Prime)
	}
	return FieldElement{Num: prod, Prime: fe.Prime}, nil
}

// Div returns fe / fey in the field of integers modulo Prime.
//
// The inverse of fey is computed with Fermat's little theorem as
// fey^(Prime-2) mod Prime, which assumes Prime is prime. big.Int.Exp keeps
// this practical for 256-bit primes; the previous linear-time Pow loop did
// not. An error is returned when the fields differ or fey is zero.
// Note: big.Int.Exp is not constant-time.
func (fe FieldElement) Div(fey FieldElement) (FieldElement, error) {
	if fe.Prime.Cmp(fey.Prime) != 0 {
		return FieldElement{},
			fmt.Errorf("Cannot div two numbers in different Fields")
	}
	if fey.Num.Sign() == 0 {
		return FieldElement{},
			fmt.Errorf("division by zero")
	}
	exp := new(big.Int).Sub(fe.Prime, big.NewInt(2))
	q := new(big.Int).Exp(fey.Num, exp, fe.Prime) // fey^-1
	q.Mul(q, fe.Num)
	q.Mod(q, fe.Prime)
	return FieldElement{Num: q, Prime: fe.Prime}, nil
}

// Pow returns fe raised to the integer power y, reduced modulo fe.Prime.
//
// Non-negative exponents are evaluated directly with big.Int.Exp, whose cost
// grows with the bit length of y rather than its value. A negative exponent
// is first reduced modulo Prime-1. By Fermat's little theorem a^(Prime-1) = 1
// for a != 0, so a^y = a^(y mod (Prime-1)); big.Int.Mod is Euclidean, so the
// reduced exponent is non-negative. This reduction assumes Prime is prime
// and fe is non-zero.
//
// x^0 is 1. The previous version returned 0, mutated the caller's y, and
// returned fe unchanged for exponents below -(Prime-1). y is never modified.
func (fe FieldElement) Pow(y *big.Int) FieldElement {
	exp := new(big.Int).Set(y)
	if exp.Sign() < 0 {
		order := new(big.Int).Sub(fe.Prime, big.NewInt(1))
		exp.Mod(exp, order)
	}
	num := new(big.Int).Exp(fe.Num, exp, fe.Prime)
	return FieldElement{Num: num, Prime: fe.Prime}
}

// Point is a point on the elliptic curve y^2 = x^3 + A*x + B whose
// coordinates are elements of a prime field.
//
// When Infinity is true the value is the point at infinity; NewPoint leaves
// X and Y as zero values in that case, and String/Eq do not read them.
type Point struct {
	A        FieldElement
	B        FieldElement
	Infinity bool
	X        FieldElement
	Y        FieldElement
}

// NewPoint builds a point on the curve y^2 = x^3 + a*x + b.
//
// x, y, a and b must all share the same Prime; otherwise an error is
// returned. When inf is true the point at infinity is returned, and x and y
// are only used for that prime check. Otherwise (x, y) must satisfy the
// curve equation, or an error is returned.
func NewPoint(inf bool, x, y, a, b FieldElement) (Point, error) {
	for _, other := range [...]FieldElement{y, a, b} {
		if x.Prime.Cmp(other.Prime) != 0 {
			return Point{}, fmt.Errorf("(%v, %v) in different Fields", x, other)
		}
	}
	if inf {
		return Point{A: a, B: b, Infinity: true}, nil
	}
	// All operands share one prime (checked above), so these cannot fail.
	ySquared, _ := y.Mul(y)
	xCubed, _ := x.Mul(x, x)
	ax, _ := a.Mul(x)
	rhs, _ := xCubed.Add(ax, b)
	if !ySquared.Eq(rhs) {
		return Point{}, fmt.Errorf("(%v, %v) is not on the curve", x, y)
	}
	return Point{A: a, B: b, X: x, Y: y}, nil
}

// String formats p as "Point(x,y)_a_b FieldElement(prime)", with
// "Infinity" in place of the coordinates for the point at infinity.
func (p Point) String() string {
	if !p.Infinity {
		return fmt.Sprintf("Point(%v,%v)_%v_%v FieldElement(%v)",
			p.X.Num, p.Y.Num, p.A.Num, p.B.Num, p.A.Prime)
	}
	return fmt.Sprintf("Point(Infinity)_%v_%v FieldElement(%v)",
		p.A.Num, p.B.Num, p.A.Prime)
}

// Eq reports whether p and py are the same point on the same curve.
//
// The curve parameters A and B must match. Two points at infinity are
// then equal; a point at infinity never equals a finite point; two finite
// points are equal when both coordinates match.
func (p Point) Eq(py Point) bool {
	if !p.A.Eq(py.A) || !p.B.Eq(py.B) {
		return false
	}
	switch {
	case p.Infinity && py.Infinity:
		return true
	case p.Infinity || py.Infinity:
		return false
	}
	return p.X.Eq(py.X) && p.Y.Eq(py.Y)
}

// Ne is the negation of Eq: it reports whether p and py differ.
func (p Point) Ne(py Point) bool {
	same := p.Eq(py)
	return !same
}

// Add returns p + py under the elliptic-curve group law.
//
// The point at infinity is the identity. Two points with equal x but
// different y, or a point with y == 0 added to itself, sum to the point at
// infinity. Otherwise the line through the points (the tangent when
// p == py) has slope s. The result is x3 = s^2 - x1 - x2 and
// y3 = s*(x1 - x3) - y1.
// An error is returned when the points lie on different curves or a field
// operation fails.
func (p Point) Add(py Point) (Point, error) {
	if p.A.Ne(py.A) || p.B.Ne(py.B) {
		return Point{},
			fmt.Errorf("Points %v, %v are not on the same curve", p, py)
	}
	switch {
	case p.Infinity:
		return py, nil
	case py.Infinity:
		return p, nil
	}

	zero, _ := NewFieldElement(big.NewInt(0), p.A.Prime)
	infinity := func() (Point, error) {
		return NewPoint(true, zero, zero, p.A, p.B)
	}
	if p.X.Eq(py.X) && p.Y.Ne(py.Y) {
		// Vertical chord.
		return infinity()
	}

	var slope, x3 FieldElement
	var err error
	if p.Eq(py) {
		if p.Y.Eq(zero) {
			// Vertical tangent.
			return infinity()
		}
		slope, x3, err = tangentSlopeAndX(p)
	} else {
		slope, x3, err = chordSlopeAndX(p, py)
	}
	if err != nil {
		return Point{}, err
	}

	// y3 = slope*(x1 - x3) - y1 is shared by both cases.
	dx, err := p.X.Sub(x3)
	if err != nil {
		return Point{}, err
	}
	y3, err := slope.Mul(dx)
	if err != nil {
		return Point{}, err
	}
	if y3, err = y3.Sub(p.Y); err != nil {
		return Point{}, err
	}
	return NewPoint(false, x3, y3, p.A, p.B)
}

// tangentSlopeAndX returns the tangent slope (3x^2 + a) / 2y at p and the
// resulting x coordinate slope^2 - 2x.
func tangentSlopeAndX(p Point) (FieldElement, FieldElement, error) {
	prime := p.A.Prime
	three, err := NewFieldElement(big.NewInt(3), prime)
	if err != nil {
		return FieldElement{}, FieldElement{}, err
	}
	num, _ := three.Mul(p.X, p.X)
	if num, err = num.Add(p.A); err != nil {
		return FieldElement{}, FieldElement{}, err
	}
	two, err := NewFieldElement(big.NewInt(2), prime)
	if err != nil {
		return FieldElement{}, FieldElement{}, err
	}
	den, err := two.Mul(p.Y)
	if err != nil {
		return FieldElement{}, FieldElement{}, err
	}
	slope, err := num.Div(den)
	if err != nil {
		return FieldElement{}, FieldElement{}, err
	}
	slopeSq, _ := slope.Mul(slope)
	twoX, err := two.Mul(p.X)
	if err != nil {
		return FieldElement{}, FieldElement{}, err
	}
	x3, err := slopeSq.Sub(twoX)
	if err != nil {
		return FieldElement{}, FieldElement{}, err
	}
	return slope, x3, nil
}

// chordSlopeAndX returns the chord slope (y2 - y1) / (x2 - x1) through p and
// q and the resulting x coordinate slope^2 - x1 - x2.
func chordSlopeAndX(p, q Point) (FieldElement, FieldElement, error) {
	dy, err := q.Y.Sub(p.Y)
	if err != nil {
		return FieldElement{}, FieldElement{}, err
	}
	dx, err := q.X.Sub(p.X)
	if err != nil {
		return FieldElement{}, FieldElement{}, err
	}
	slope, err := dy.Div(dx)
	if err != nil {
		return FieldElement{}, FieldElement{}, err
	}
	slopeSq, _ := slope.Mul(slope)
	x3, err := slopeSq.Sub(p.X, q.X)
	if err != nil {
		return FieldElement{}, FieldElement{}, err
	}
	return slope, x3, nil
}

// ScalarMul returns s·p (p added to itself s times) using binary
// double-and-add. The bits of |s| are scanned from least to most
// significant while a running addend is doubled at every step.
//
// Every bit up to s.BitLen() is visited via big.Int.Bit. The previous
// per-word loop stopped doubling once a word's remaining bits were zero,
// which gave wrong results for scalars wider than one machine word (e.g.
// 256-bit secp256k1 scalars). A negative s yields |s|·(-p), where
// -p = (x, -y). s is not modified.
//
// Errors from Point.Add are discarded because the signature has no error
// result. They only arise in degenerate fields (Prime <= 3).
func (p Point) ScalarMul(s *big.Int) Point {
	zero, _ := NewFieldElement(big.NewInt(0), p.A.Prime)
	result, _ := NewPoint(true, zero, zero, p.A, p.B)

	addend := p
	if s.Sign() < 0 && !p.Infinity {
		negY, _ := zero.Sub(p.Y)
		addend = Point{A: p.A, B: p.B, X: p.X, Y: negY}
	}

	k := new(big.Int).Abs(s)
	for i, n := 0, k.BitLen(); i < n; i++ {
		if k.Bit(i) == 1 {
			result, _ = result.Add(addend)
		}
		addend, _ = addend.Add(addend)
	}
	return result
}

ecc_test.go

package ecc

import (
	"math/big"
	"testing"
)

// TestFieldElementEq checks Eq for distinct and identical elements of F_13.
// The failure message now reports the actual result; the old second branch
// printed `want` in the "got" slot and left a dead `want := false`.
func TestFieldElementEq(t *testing.T) {
	p := big.NewInt(13)
	a, _ := NewFieldElement(big.NewInt(7), p)
	b, _ := NewFieldElement(big.NewInt(6), p)
	tests := []struct {
		x, y FieldElement
		want bool
	}{
		{a, b, false},
		{a, a, true},
	}
	for _, tt := range tests {
		if got := tt.x.Eq(tt.y); got != tt.want {
			t.Errorf("%v.Eq(%v) got %v, want %v", tt.x, tt.y, got, tt.want)
		}
	}
}

// TestFieldElementNe checks that Ne is true for distinct elements of F_13
// and false for an element compared with itself.
func TestFieldElementNe(t *testing.T) {
	prime := big.NewInt(13)
	seven, _ := NewFieldElement(big.NewInt(7), prime)
	six, _ := NewFieldElement(big.NewInt(6), prime)
	if got := seven.Ne(six); !got {
		t.Errorf("%v.Ne(%v) got %v, want true", seven, six, got)
	}
	if got := seven.Ne(seven); got {
		t.Errorf("%v.Ne(%v) got %v, want false", seven, seven, got)
	}
}

func TestFieldElementAdd(t *testing.T) {
	p := big.NewInt(13)
	a, _ := NewFieldElement(big.NewInt(7), p)
	b, _ := NewFieldElement(big.NewInt(12), p)
	c, _ := NewFieldElement(big.NewInt(6), p)
	d, _ := a.Add(b)
	if d.Ne(c) {
		t.Errorf("%v.Add(%v) got %v, want %v", a, b, d, c)
	}
}

// TestFieldElementSub checks subtraction in F_19, including a case that
// wraps below zero.
func TestFieldElementSub(t *testing.T) {
	prime := big.NewInt(19)
	cases := [...]struct{ x, y, want int64 }{
		{11, 9, 2},
		{6, 13, 12},
	}
	for i := range cases {
		c := cases[i]
		x, _ := NewFieldElement(big.NewInt(c.x), prime)
		y, _ := NewFieldElement(big.NewInt(c.y), prime)
		want, _ := NewFieldElement(big.NewInt(c.want), prime)
		if got, _ := x.Sub(y); got.Ne(want) {
			t.Errorf("%v.Sub(%v) got %v, want %v", x, y, got, want)
		}
	}
}

// TestFieldElementMul checks 3 * 12 = 10 in F_13.
func TestFieldElementMul(t *testing.T) {
	prime := big.NewInt(13)
	three, _ := NewFieldElement(big.NewInt(3), prime)
	twelve, _ := NewFieldElement(big.NewInt(12), prime)
	want, _ := NewFieldElement(big.NewInt(10), prime)
	if got, _ := three.Mul(twelve); got.Ne(want) {
		t.Errorf("%v.Mul(%v) got %v, want %v", three, twelve, got, want)
	}
}

// TestFieldElementDiv checks division in F_19 (2/7 = 3, 7/5 = 9).
func TestFieldElementDiv(t *testing.T) {
	prime := big.NewInt(19)
	for _, c := range []struct{ num, den, want int64 }{
		{2, 7, 3},
		{7, 5, 9},
	} {
		num, _ := NewFieldElement(big.NewInt(c.num), prime)
		den, _ := NewFieldElement(big.NewInt(c.den), prime)
		want, _ := NewFieldElement(big.NewInt(c.want), prime)
		if got, _ := num.Div(den); got.Ne(want) {
			t.Errorf("%v.Div(%v) got %v, want %v", num, den, got, want)
		}
	}
}

// TestFieldElementPow checks a positive and a negative exponent in F_13.
func TestFieldElementPow(t *testing.T) {
	prime := big.NewInt(13)
	for _, c := range []struct{ base, exp, want int64 }{
		{3, 3, 1},
		{7, -3, 8},
	} {
		x, _ := NewFieldElement(big.NewInt(c.base), prime)
		want, _ := NewFieldElement(big.NewInt(c.want), prime)
		if got := x.Pow(big.NewInt(c.exp)); got.Ne(want) {
			t.Errorf("%v.Pow(%v) got %v, want %v", x, c.exp, got, want)
		}
	}
}

func TestPointString(t *testing.T) {
	p := big.NewInt(223)
	a, _ := NewFieldElement(big.NewInt(0), p)
	b, _ := NewFieldElement(big.NewInt(7), p)
	x, _ := NewFieldElement(big.NewInt(192), p)
	y, _ := NewFieldElement(big.NewInt(105), p)
	point, _ := NewPoint(false, x, y, a, b)
	want := "Point(192,105)_0_7 FieldElement(223)"
	got := point.String()
	if got != want {
		t.Errorf("%v.String() got %v, want %v", point, got, want)
	}
}

// TestPointOnCurve checks that NewPoint accepts points on
// y^2 = x^3 + 7 over F_223 and rejects points that are not on it.
// When NewPoint fails, the returned Point is the zero value, so the failure
// message reports the coordinates and the error instead of that Point.
func TestPointOnCurve(t *testing.T) {
	prime := big.NewInt(223)
	a, _ := NewFieldElement(big.NewInt(0), prime)
	b, _ := NewFieldElement(big.NewInt(7), prime)
	tests := []struct {
		x, y    int64
		onCurve bool
	}{
		{192, 105, true},
		{17, 56, true},
		{1, 193, true},
		{200, 119, false},
		{42, 99, false},
	}
	for _, tt := range tests {
		x, _ := NewFieldElement(big.NewInt(tt.x), prime)
		y, _ := NewFieldElement(big.NewInt(tt.y), prime)
		point, err := NewPoint(false, x, y, a, b)
		switch {
		case tt.onCurve && err != nil:
			t.Errorf("(%v, %v) is not on curve: %v", tt.x, tt.y, err)
		case !tt.onCurve && err == nil:
			t.Errorf("%v is on curve", point)
		}
	}
}

func TestPointAdd(t *testing.T) {
	p := big.NewInt(223)
	a, _ := NewFieldElement(big.NewInt(0), p)
	b, _ := NewFieldElement(big.NewInt(7), p)
	tests := []struct {
		x1, y1, x2, y2, x3, y3 int64
	}{
		{170, 142, 60, 139, 220, 181},
		{47, 71, 17, 56, 215, 68},
		{143, 98, 76, 66, 47, 71},
	}
	for _, test := range tests {
		x1, _ := NewFieldElement(big.NewInt(test.x1), p)
		y1, _ := NewFieldElement(big.NewInt(test.y1), p)
		p1, _ := NewPoint(false, x1, y1, a, b)
		x2, _ := NewFieldElement(big.NewInt(test.x2), p)
		y2, _ := NewFieldElement(big.NewInt(test.y2), p)
		p2, _ := NewPoint(false, x2, y2, a, b)
		got, _ := p1.Add(p2)
		x3, _ := NewFieldElement(big.NewInt(test.x3), p)
		y3, _ := NewFieldElement(big.NewInt(test.y3), p)
		want, _ := NewPoint(false, x3, y3, a, b)
		if got.Ne(want) {
			t.Errorf("%v.Add(%v) got %v, want %v", p1, p2, got, want)
		}
	}
}

func TestPointScalarMul(t *testing.T) {
	p := big.NewInt(223)
	a, _ := NewFieldElement(big.NewInt(0), p)
	b, _ := NewFieldElement(big.NewInt(7), p)
	x, _ := NewFieldElement(big.NewInt(47), p)
	y, _ := NewFieldElement(big.NewInt(71), p)
	point, _ := NewPoint(false, x, y, a, b)
	tests := []struct {
		s, x, y int64
	}{
		{2, 36, 111},
		{4, 194, 51},
		{8, 116, 55},
	}
	for _, test := range tests {
		got := point.ScalarMul(big.NewInt(test.s))
		sx, _ := NewFieldElement(big.NewInt(test.x), p)
		sy, _ := NewFieldElement(big.NewInt(test.y), p)
		want, _ := NewPoint(false, sx, sy, a, b)
		if got.Ne(want) {
			t.Errorf("%v.ScalarMul(%v) got %v, want %v", point, test.s, got, want)
		}
	}
	s := big.NewInt(21)
	got := point.ScalarMul(s)
	z, _ := NewFieldElement(big.NewInt(0), p)
	want, _ := NewPoint(true, z, z, a, b)
	if got.Ne(want) {
		t.Errorf("%v.ScalarMul(%v) got %v, want %v", point, s, got, want)
	}
}

入出力結果

% go test
# bitcoin/ecc [bitcoin/ecc.test]
./ecc_test.go:140:14: undefined: NewPoint
FAIL	bitcoin/ecc [build failed]
% go test
--- FAIL: TestFieldElementEq (0.00s)
panic: runtime error: invalid memory address or nil pointer dereference [recovered]
	panic: runtime error: invalid memory address or nil pointer dereference
[signal SIGSEGV: segmentation violation code=0x1 addr=0x0 pc=0x110d1e1]

goroutine 18 [running]:
testing.tRunner.func1.1(0x1135c00, 0x123b830)
	/opt/local/lib/go/src/testing/testing.go:1072 +0x30d
testing.tRunner.func1(0xc000082600)
	/opt/local/lib/go/src/testing/testing.go:1075 +0x41a
panic(0x1135c00, 0x123b830)
	/opt/local/lib/go/src/runtime/panic.go:969 +0x1b9
math/big.(*Int).Cmp(0x0, 0xc00009be80, 0xc00009be80)
	/opt/local/lib/go/src/math/big/int.go:328 +0x41
bitcoin/ecc.FieldElement.Eq(0x0, 0x0, 0xc0000d9f20, 0x1, 0x1, 0xc00008e4d0)
	/.../go/src/bitcoin/ecc/ecc.go:34 +0x66
bitcoin/ecc.TestFieldElementEq(0xc000082600)
	/.../go/src/bitcoin/ecc/ecc_test.go:12 +0x185
testing.tRunner(0xc000082600, 0x1162a08)
	/opt/local/lib/go/src/testing/testing.go:1123 +0xef
created by testing.(*T).Run
	/opt/local/lib/go/src/testing/testing.go:1168 +0x2b3
exit status 2
FAIL	bitcoin/ecc	0.224s
% go test
PASS
ok  	bitcoin/ecc	0.260s
% go test
PASS
ok  	bitcoin/ecc	0.308s
% go test
# bitcoin/ecc [bitcoin/ecc.test]
./ecc_test.go:183:15: p1.Add undefined (type Point has no field or method Add)
FAIL	bitcoin/ecc [build failed]
% go test
PASS
ok  	bitcoin/ecc	0.320s
% go test
# bitcoin/ecc [bitcoin/ecc.test]
./ecc_test.go:208:15: point.ScalarMul undefined (type Point has no field or method ScalarMul)
./ecc_test.go:217:14: point.ScalarMul undefined (type Point has no field or method ScalarMul)
FAIL	bitcoin/ecc [build failed]
% go test
PASS
ok  	bitcoin/ecc	0.330s
% go test -v
=== RUN   TestFieldElementEq
--- PASS: TestFieldElementEq (0.00s)
=== RUN   TestFieldElementNe
--- PASS: TestFieldElementNe (0.00s)
=== RUN   TestFieldElementAdd
--- PASS: TestFieldElementAdd (0.00s)
=== RUN   TestFieldElementSub
--- PASS: TestFieldElementSub (0.00s)
=== RUN   TestFieldElementMul
--- PASS: TestFieldElementMul (0.00s)
=== RUN   TestFieldElementDiv
--- PASS: TestFieldElementDiv (0.00s)
=== RUN   TestFieldElementPow
--- PASS: TestFieldElementPow (0.00s)
=== RUN   TestPointString
--- PASS: TestPointString (0.00s)
=== RUN   TestPointOnCurve
--- PASS: TestPointOnCurve (0.00s)
=== RUN   TestPointAdd
--- PASS: TestPointAdd (0.02s)
=== RUN   TestPointScalarMul
--- PASS: TestPointScalarMul (0.09s)
PASS
ok  	bitcoin/ecc	0.393s
%