計算機科学のブログ

楕円曲線暗号 楕円曲線のスカラー倍算 有限巡回群、位数 Goによるコーディング

プログラミング・ビットコイン ―ゼロからビットコインをプログラムする方法 (Jimmy Song(著)、中川 卓俊(監修)、住田 和則(監修)、中村 昭雄(監修)、星野 靖子(翻訳)、オライリー・ジャパン)の3章(楕円曲線暗号)、3.6(楕円曲線のスカラー倍算)の練習問題4、3.7(数学の群)の練習問題5の解答をPythonではなくGoで求めてみる。

練習問題4

コード

package main

import (
	"bitcoin/ecc"
	"fmt"
)

// main doubles two points on y^2 = x^3 + 7 over F_223 and then prints
// 2, 4, 8 and 21 times the point (47, 71), computed by repeated addition.
func main() {
	const prime = 223
	a, _ := ecc.NewFieldElement(0, prime)
	b, _ := ecc.NewFieldElement(7, prime)

	// newPoint builds a curve point from raw coordinates. Errors are ignored
	// because every coordinate pair below lies on the curve.
	newPoint := func(x, y int) ecc.Point {
		fx, _ := ecc.NewFieldElement(x, prime)
		fy, _ := ecc.NewFieldElement(y, prime)
		pt, _ := ecc.NewPoint(false, fx, fy, a, b)
		return pt
	}

	for _, xy := range [][2]int{{192, 105}, {143, 98}} {
		pt := newPoint(xy[0], xy[1])
		doubled, _ := pt.Add(pt)
		fmt.Printf("2 %v = %v\n", pt, doubled)
	}

	g := newPoint(47, 71)
	for _, s := range []int{2, 4, 8, 21} {
		// s*g = g + g + ... + g, i.e. s-1 additions starting from g.
		acc := g
		for k := s; k > 1; k-- {
			acc, _ = acc.Add(g)
		}
		fmt.Printf("%v %v = %v\n", s, g, acc)
	}
}

入出力結果

% go run ./main.go
2 Point(192,105)_0_7 FieldElement(223) = Point(49,71)_0_7 FieldElement(223)
2 Point(143,98)_0_7 FieldElement(223) = Point(64,168)_0_7 FieldElement(223)
2 Point(47,71)_0_7 FieldElement(223) = Point(36,111)_0_7 FieldElement(223)
4 Point(47,71)_0_7 FieldElement(223) = Point(194,51)_0_7 FieldElement(223)
8 Point(47,71)_0_7 FieldElement(223) = Point(116,55)_0_7 FieldElement(223)
21 Point(47,71)_0_7 FieldElement(223) = Point(Infinity)_0_7 FieldElement(223)
%

練習問題5

コード

package main

import (
	"bitcoin/ecc"
	"fmt"
)

// main finds the order of the point (15, 86) on y^2 = x^3 + 7 over F_223:
// the smallest n >= 1 with n*P equal to the point at infinity.
func main() {
	const prime = 223
	a, _ := ecc.NewFieldElement(0, prime)
	b, _ := ecc.NewFieldElement(7, prime)
	x, _ := ecc.NewFieldElement(15, prime)
	y, _ := ecc.NewFieldElement(86, prime)
	gen, _ := ecc.NewPoint(false, x, y, a, b)

	zero, _ := ecc.NewFieldElement(0, prime)
	identity, _ := ecc.NewPoint(true, zero, zero, a, b)

	order := 1
	for !gen.ScalarMul(order).Eq(identity) {
		order++
	}
	fmt.Println("位数", order)
}

入出力結果

% go run ./main.go
位数 7

3.8のスカラー倍算のコーディング、テスト

コード

ecc.go

// Package ecc (Elliptic Curve Cryptography, 楕円曲線暗号)
package ecc

import "fmt"

// FieldElement is an element Num of the finite field F_Prime.
// NewFieldElement only creates elements with 0 <= Num < Prime; the zero
// value (Prime == 0) is treated as invalid by Eq.
type FieldElement struct {
	Num, Prime int
}

// String formats the element as "FieldElement_<prime>(<num>)".
func (fe FieldElement) String() string {
	prime, num := fmt.Sprint(fe.Prime), fmt.Sprint(fe.Num)
	return "FieldElement_" + prime + "(" + num + ")"
}

// NewFieldElement returns the element n of F_p. It returns an error (and the
// zero FieldElement) when n lies outside the half-open range [0, p).
func NewFieldElement(n, p int) (FieldElement, error) {
	if 0 <= n && n < p {
		return FieldElement{Num: n, Prime: p}, nil
	}
	return FieldElement{}, fmt.Errorf("%v not in field range 0 to %v", n, p-1)
}

// Eq reports whether fe and every element of fes have the same Num and
// Prime. A zero-value element (Prime == 0) never compares equal, not even
// to itself. With no fes, any element with a nonzero Prime is equal.
func Eq(fe FieldElement, fes ...FieldElement) bool {
	if fe.Prime == 0 {
		return false
	}
	for i := range fes {
		if fes[i] != fe {
			return false
		}
	}
	return true
}

// Ne is the negation of Eq. It reports whether fe differs from at least one
// element of fes, or is a zero-value element.
func Ne(fe FieldElement, fes ...FieldElement) bool {
	same := Eq(fe, fes...)
	return !same
}

// Add returns fe + fes[0] + fes[1] + ... in F_p, where p is fe.Prime.
// It returns an error if any addend belongs to a different field. With no
// fes, fe is returned unchanged.
func Add(fe FieldElement, fes ...FieldElement) (FieldElement, error) {
	sum := fe
	for _, term := range fes {
		if term.Prime != sum.Prime {
			return FieldElement{}, fmt.Errorf("Cannot add two numbers in different Fields")
		}
		sum.Num = (sum.Num + term.Num) % fe.Prime
	}
	return sum, nil
}

// Sub returns fe - fes[0] - fes[1] - ... in F_p, normalised into [0, p).
// It returns an error if any subtrahend belongs to a different field.
func Sub(fe FieldElement, fes ...FieldElement) (FieldElement, error) {
	prime := fe.Prime
	diff := fe.Num
	for _, term := range fes {
		if term.Prime != prime {
			return FieldElement{}, fmt.Errorf("Cannot sub two numbers in different Fields")
		}
		// Go's % keeps the dividend's sign, so fold negatives back into range.
		if diff = (diff - term.Num) % prime; diff < 0 {
			diff += prime
		}
	}
	return FieldElement{Num: diff, Prime: prime}, nil
}

// Mul returns the product fe * fes[0] * fes[1] * ... in F_p, where p is
// fe.Prime. Every operand must belong to the same field, otherwise an error
// is returned. This is checked before any multiplication, so a zero factor
// cannot mask a field mismatch. With no fes, fe is returned unchanged.
// An element with a non-positive Prime (e.g. the zero value) is rejected
// instead of reaching a modulo by zero.
func Mul(fe FieldElement, fes ...FieldElement) (FieldElement, error) {
	p := fe.Prime
	if p <= 0 {
		return FieldElement{}, fmt.Errorf("invalid field prime %v", p)
	}
	for _, fe1 := range fes {
		if fe1.Prime != p {
			return FieldElement{}, fmt.Errorf("Cannot mul two numbers in different Fields")
		}
	}
	n := fe.Num
	for _, fe1 := range fes {
		n = mulMod(n, fe1.Num, p)
	}
	return FieldElement{Num: n, Prime: p}, nil
}

// mulMod returns a*b mod m for operands in [0, m) using double-and-add.
// It takes O(log b) steps, and no intermediate value exceeds 2m, so it
// cannot overflow int the way a direct a*b would for large m.
func mulMod(a, b, m int) int {
	res := 0
	for ; b > 0; b >>= 1 {
		if b&1 == 1 {
			res = (res + a) % m
		}
		a = (a + a) % m
	}
	return res
}

// Div returns fe1 / fe2 in F_p, computed as fe1 * fe2^(p-2). For prime p and
// nonzero fe2, Fermat's little theorem makes fe2^(p-2) the multiplicative
// inverse of fe2. It returns an error for operands from different fields
// and for division by zero.
func Div(fe1, fe2 FieldElement) (FieldElement, error) {
	switch {
	case fe1.Prime != fe2.Prime:
		return FieldElement{},
			fmt.Errorf("Cannot div two numbers in diffrerent Fields")
	case fe2.Num == 0:
		return FieldElement{},
			fmt.Errorf("division by zero")
	}
	inverse := Pow(fe2, fe2.Prime-2)
	return Mul(fe1, inverse)
}

// Pow returns fe raised to the power n in F_p (p = fe.Prime).
//
// For a nonzero base, the exponent is reduced modulo p-1 (Fermat's little
// theorem: a^(p-1) = 1). Any negative n therefore yields a power of the
// multiplicative inverse; for example, Pow(a, -1) is 1/a. A zero base gives
// 1 for n == 0 and 0 otherwise. The power is computed by square-and-multiply,
// which takes O(log p) calls to Mul. An element with Prime < 2 (such as the
// zero value) has no field to work in and yields the zero FieldElement.
func Pow(fe FieldElement, n int) FieldElement {
	p := fe.Prime
	if p < 2 {
		return FieldElement{}
	}
	one := FieldElement{Num: 1, Prime: p}
	if fe.Num == 0 {
		// Fermat's reduction is only valid for nonzero bases.
		if n == 0 {
			return one
		}
		return fe
	}
	// Go's % keeps the sign of the dividend, so fold negatives into [0, p-1).
	e := n % (p - 1)
	if e < 0 {
		e += p - 1
	}
	result, base := one, fe
	for ; e > 0; e >>= 1 {
		if e&1 == 1 {
			// Both operands are in fe's field, so Mul cannot fail here.
			result, _ = Mul(result, base)
		}
		base, _ = Mul(base, base)
	}
	return result
}

// Point is a point on the elliptic curve y^2 = x^3 + Ax + B over a finite
// field. When Infinity is set, the value is the point at infinity, which
// Add treats as the identity; NewPoint leaves X and Y zero-valued in that case.
type Point struct {
	A        FieldElement // curve coefficient a
	B        FieldElement // curve coefficient b
	Infinity bool         // set for the point at infinity
	X        FieldElement // x-coordinate
	Y        FieldElement // y-coordinate
}

// NewPoint builds a point on the curve y^2 = x^3 + ax + b. All four field
// elements must share x's prime, otherwise an error is returned.
//
// If inf is set, it returns the point at infinity (X and Y left zero-valued)
// without checking the curve equation. Otherwise (x, y) must satisfy the
// equation, or an error is returned.
func NewPoint(inf bool, x, y, a, b FieldElement) (Point, error) {
	for _, other := range [...]FieldElement{y, a, b} {
		if other.Prime != x.Prime {
			return Point{}, fmt.Errorf("(%v, %v) in different Fields", x, other)
		}
	}
	if inf {
		return Point{A: a, B: b, Infinity: true}, nil
	}
	lhs, _ := Mul(y, y)
	xCubed, _ := Mul(x, x, x)
	ax, _ := Mul(a, x)
	rhs, _ := Add(xCubed, ax, b)
	if !Eq(lhs, rhs) {
		return Point{}, fmt.Errorf("(%v, %v) is not on the curve", x, y)
	}
	return Point{A: a, B: b, X: x, Y: y}, nil
}
// String renders the point as "Point(x,y)_a_b FieldElement(prime)". For the
// point at infinity it renders "Point(Infinity)_a_b FieldElement(prime)".
// The prime is taken from the curve coefficient A.
func (p Point) String() string {
	coords := fmt.Sprintf("%v,%v", p.X.Num, p.Y.Num)
	if p.Infinity {
		coords = "Infinity"
	}
	return fmt.Sprintf("Point(%v)_%v_%v FieldElement(%v)",
		coords, p.A.Num, p.B.Num, p.A.Prime)
}

// Eq reports whether p and p1 are the same point on the same curve. The
// coefficients A and B must match. Two points at infinity are equal whatever
// X/Y values they happen to carry, and a point at infinity never equals a
// finite point. Finite points are equal when their coordinates match.
func (p Point) Eq(p1 Point) bool {
	if p.A != p1.A || p.B != p1.B {
		return false
	}
	if p.Infinity || p1.Infinity {
		return p.Infinity == p1.Infinity
	}
	return p.X == p1.X && p.Y == p1.Y
}

// Ne reports whether p and p1 differ; it is the negation of Eq.
func (p Point) Ne(p1 Point) bool {
	same := p.Eq(p1)
	return !same
}

// Add returns p + p1 under the elliptic-curve group law.
//
// The point at infinity is the identity. Two points with the same x but
// different y (P and -P), or a point with y = 0 added to itself, sum to the
// point at infinity because the line through them is vertical. Otherwise the
// result is the third intersection of the secant line (or, when p == p1, the
// tangent) with the curve, reflected across the x-axis.
//
// An error is returned when the points are on different curves or when any
// field operation fails.
func (p Point) Add(p1 Point) (Point, error) {
	if p.A != p1.A || p.B != p1.B {
		return Point{},
			fmt.Errorf("Points %v, %v are not on the same curve", p, p1)
	}
	if p.Infinity {
		return p1, nil
	}
	if p1.Infinity {
		return p, nil
	}
	zero := FieldElement{Num: 0, Prime: p.A.Prime}
	if p.X == p1.X && p.Y != p1.Y {
		// P + (-P): the secant is vertical.
		return NewPoint(true, zero, zero, p.A, p.B)
	}
	if p.Eq(p1) && Eq(p.Y, zero) {
		// Doubling a point on the x-axis: the tangent is vertical.
		return NewPoint(true, zero, zero, p.A, p.B)
	}
	s, err := p.slope(p1)
	if err != nil {
		return Point{}, err
	}
	// x3 = s^2 - x1 - x2; when doubling, x1 == x2, so this is s^2 - 2*x1.
	ss, err := Mul(s, s)
	if err != nil {
		return Point{}, err
	}
	x, err := Sub(ss, p.X, p1.X)
	if err != nil {
		return Point{}, err
	}
	// y3 = s*(x1 - x3) - y1
	dx, err := Sub(p.X, x)
	if err != nil {
		return Point{}, err
	}
	sdx, err := Mul(s, dx)
	if err != nil {
		return Point{}, err
	}
	y, err := Sub(sdx, p.Y)
	if err != nil {
		return Point{}, err
	}
	return NewPoint(false, x, y, p.A, p.B)
}

// slope returns the slope of the line used to add p and p1. When they are
// the same point this is the tangent (3*x1^2 + a) / (2*y1); otherwise it is
// the secant (y2 - y1) / (x2 - x1). The multiples 3*x^2 and 2*y are formed
// by repeated addition, which works in every F_p (a literal 3 cannot be
// built with NewFieldElement when p <= 3).
func (p Point) slope(p1 Point) (FieldElement, error) {
	if p.Eq(p1) {
		xx, err := Mul(p.X, p.X)
		if err != nil {
			return FieldElement{}, err
		}
		num, err := Add(xx, xx, xx, p.A)
		if err != nil {
			return FieldElement{}, err
		}
		den, err := Add(p.Y, p.Y)
		if err != nil {
			return FieldElement{}, err
		}
		return Div(num, den)
	}
	num, err := Sub(p1.Y, p.Y)
	if err != nil {
		return FieldElement{}, err
	}
	den, err := Sub(p1.X, p.X)
	if err != nil {
		return FieldElement{}, err
	}
	return Div(num, den)
}

// ScalarMul returns s * p, i.e. p added to itself s times. It uses binary
// double-and-add, which takes O(log |s|) group operations. s == 0 yields
// the point at infinity. A negative s yields |s| * (-p), where -p is p
// reflected across the x-axis (y replaced by -y).
func (p Point) ScalarMul(s int) Point {
	zero := FieldElement{Num: 0, Prime: p.A.Prime}
	result, _ := NewPoint(true, zero, zero, p.A, p.B)
	cur := p
	// Iterate over the unsigned magnitude of s. Right-shifting a negative int
	// never reaches 0 (-1 >> 1 == -1), and -s overflows for the minimum int,
	// so |s| is computed as uint(-(s + 1)) + 1.
	var coef uint
	if s < 0 {
		coef = uint(-(s + 1)) + 1
		if !p.Infinity {
			// For points built by NewPoint, Y shares the curve's prime, so
			// Sub cannot fail.
			negY, _ := Sub(zero, p.Y)
			cur.Y = negY
		}
	} else {
		coef = uint(s)
	}
	for ; coef != 0; coef >>= 1 {
		if coef&1 == 1 {
			// The signature returns only a Point, so Add errors are dropped;
			// every operand lies on p's curve.
			result, _ = result.Add(cur)
		}
		cur, _ = cur.Add(cur)
	}
	return result
}

ecc_test.go

package ecc

import (
	"testing"
)

// TestFieldElementEq checks that Eq distinguishes different elements of F_13
// and accepts an element compared with itself.
func TestFieldElementEq(t *testing.T) {
	seven, _ := NewFieldElement(7, 13)
	six, _ := NewFieldElement(6, 13)
	if got := Eq(seven, six); got {
		t.Errorf("Eq(%v, %v) got %v, want false", seven, six, got)
	}
	if got := Eq(seven, seven); !got {
		t.Errorf("Eq(%v, %v) got %v, want true", seven, seven, got)
	}
}
// TestFieldElementNe checks Ne on distinct and identical elements of F_13.
func TestFieldElementNe(t *testing.T) {
	seven, _ := NewFieldElement(7, 13)
	six, _ := NewFieldElement(6, 13)
	if got := Ne(seven, six); !got {
		t.Errorf("Ne(%v, %v) got %v, want true", seven, six, got)
	}
	if got := Ne(seven, seven); got {
		t.Errorf("Ne(%v, %v) got %v, want false", seven, seven, got)
	}
}

// TestAdd checks that 7 + 12 = 6 in F_13.
func TestAdd(t *testing.T) {
	const prime = 13
	x, _ := NewFieldElement(7, prime)
	y, _ := NewFieldElement(12, prime)
	want, _ := NewFieldElement(6, prime)
	if got, _ := Add(x, y); Ne(got, want) {
		t.Errorf("Add(%v, %v) got %v, want %v", x, y, got, want)
	}
}

// TestSub checks subtraction in F_19, including a wrap-around result.
func TestSub(t *testing.T) {
	const prime = 19
	// Each row is {minuend, subtrahend, expected difference}.
	for _, row := range [][3]int{{11, 9, 2}, {6, 13, 12}} {
		a, _ := NewFieldElement(row[0], prime)
		b, _ := NewFieldElement(row[1], prime)
		want, _ := NewFieldElement(row[2], prime)
		if got, _ := Sub(a, b); Ne(got, want) {
			t.Errorf("Sub(%v, %v) got %v, want %v", a, b, got, want)
		}
	}
}

// TestMul checks that 3 * 12 = 10 in F_13.
func TestMul(t *testing.T) {
	const prime = 13
	lhs, _ := NewFieldElement(3, prime)
	rhs, _ := NewFieldElement(12, prime)
	want, _ := NewFieldElement(10, prime)
	if got, _ := Mul(lhs, rhs); Ne(got, want) {
		t.Errorf("Mul(%v, %v) got %v, want %v", lhs, rhs, got, want)
	}
}

// TestDiv checks division in F_19 against hand-computed quotients.
func TestDiv(t *testing.T) {
	const prime = 19
	// Each row is {numerator, denominator, expected quotient}.
	for _, row := range [][3]int{{2, 7, 3}, {7, 5, 9}} {
		num, _ := NewFieldElement(row[0], prime)
		den, _ := NewFieldElement(row[1], prime)
		want, _ := NewFieldElement(row[2], prime)
		if got, _ := Div(num, den); Ne(got, want) {
			t.Errorf("Div(%v, %v) got %v, want %v", num, den, got, want)
		}
	}
}
// TestPow checks positive and negative exponents in F_13.
func TestPow(t *testing.T) {
	const prime = 13
	cases := []struct{ base, exp, want int }{
		{3, 3, 1},
		{7, -3, 8},
	}
	for _, c := range cases {
		a, _ := NewFieldElement(c.base, prime)
		want, _ := NewFieldElement(c.want, prime)
		if got := Pow(a, c.exp); Ne(got, want) {
			t.Errorf("Pow(%v, %v) got %v, want %v", a, c.exp, got, want)
		}
	}
}

func TestPointString(t *testing.T) {
	p := 223
	a, _ := NewFieldElement(0, p)
	b, _ := NewFieldElement(7, p)
	x, _ := NewFieldElement(192, p)
	y, _ := NewFieldElement(105, p)
	point, _ := NewPoint(false, x, y, a, b)
	want := "Point(192,105)_0_7 FieldElement(223)"
	got := point.String()
	if got != want {
		t.Errorf("%v.String() got %v, want %v", point, got, want)
	}
}

// TestPointOnCurve checks that NewPoint accepts points satisfying
// y^2 = x^3 + 7 over F_223 and rejects points that do not.
func TestPointOnCurve(t *testing.T) {
	const prime = 223
	a, _ := NewFieldElement(0, prime)
	b, _ := NewFieldElement(7, prime)
	newPoint := func(x, y int) (Point, error) {
		fx, _ := NewFieldElement(x, prime)
		fy, _ := NewFieldElement(y, prime)
		return NewPoint(false, fx, fy, a, b)
	}
	for _, xy := range [][2]int{{192, 105}, {17, 56}, {1, 193}} {
		if _, err := newPoint(xy[0], xy[1]); err != nil {
			// NewPoint returns the zero Point on error, so report the inputs
			// and the error rather than the (meaningless) returned point.
			t.Errorf("(%v, %v) is not on curve: %v", xy[0], xy[1], err)
		}
	}
	for _, xy := range [][2]int{{200, 119}, {42, 99}} {
		if point, err := newPoint(xy[0], xy[1]); err == nil {
			t.Errorf("%v is on curve", point)
		}
	}
}

func TestPointAdd(t *testing.T) {
	p := 223
	a, _ := NewFieldElement(0, p)
	b, _ := NewFieldElement(7, p)
	tests := []struct {
		x1, y1, x2, y2, x3, y3 int
	}{
		{170, 142, 60, 139, 220, 181},
		{47, 71, 17, 56, 215, 68},
		{143, 98, 76, 66, 47, 71},
	}
	for _, test := range tests {
		x1, _ := NewFieldElement(test.x1, p)
		y1, _ := NewFieldElement(test.y1, p)
		p1, _ := NewPoint(false, x1, y1, a, b)
		x2, _ := NewFieldElement(test.x2, p)
		y2, _ := NewFieldElement(test.y2, p)
		p2, _ := NewPoint(false, x2, y2, a, b)
		got, _ := p1.Add(p2)
		x3, _ := NewFieldElement(test.x3, p)
		y3, _ := NewFieldElement(test.y3, p)
		want, _ := NewPoint(false, x3, y3, a, b)
		if got.Ne(want) {
			t.Errorf("%v.Add(%v) got %v, want %v", p1, p2, got, want)
		}
	}
}

func TestPointScalarMul(t *testing.T) {
	p := 223
	a, _ := NewFieldElement(0, p)
	b, _ := NewFieldElement(7, p)
	x, _ := NewFieldElement(47, p)
	y, _ := NewFieldElement(71, p)
	point, _ := NewPoint(false, x, y, a, b)
	tests := []struct {
		s, x, y int
	}{
		{2, 36, 111},
		{4, 194, 51},
		{8, 116, 55},
	}
	for _, test := range tests {
		got := point.ScalarMul(test.s)
		sx, _ := NewFieldElement(test.x, p)
		sy, _ := NewFieldElement(test.y, p)
		want, _ := NewPoint(false, sx, sy, a, b)
		if got.Ne(want) {
			t.Errorf("%v.ScalarMul(%v) got %v, want %v", point, test.s, got, want)
		}
	}
	s := 21
	got := point.ScalarMul(s)
	z, _ := NewFieldElement(0, p)
	want, _ := NewPoint(true, z, z, a, b)
	if got.Ne(want) {
		t.Errorf("%v.ScalarMul(%v) got %v, want %v", point, s, got, want)
	}
}

入出力結果

% go test
# bitcoin/ecc [bitcoin/ecc.test]
./ecc_test.go:201:15: point.ScalarMul undefined (type Point has no field or method ScalarMul)
FAIL	bitcoin/ecc [build failed]
% go test
PASS
ok  	bitcoin/ecc	0.322s
% go test
PASS
ok  	bitcoin/ecc	0.305s
% go test
PASS
ok  	bitcoin/ecc	0.231s
%