package bfa

import (
	"flag"
	"github.com/ethereum/go-ethereum/common"
	"log"
	"os"
	"testing"
)

// NumBlocks is how many consecutive blocks the walk tests step back over and
// how many distinct blocks each range benchmark cycles through.
const NumBlocks int64 = 1000

var (
	// node is the client shared by every test and benchmark; TestMain dials it.
	node *Node

	// numbers maps a range label ("lo", "mid", "hi") to the first block
	// number of that range.
	numbers map[string]int64

	// hashes maps the same range labels to the block hashes TestMain
	// collects for that range.
	hashes map[string][]common.Hash
)

// BenchmarkBlockByNumber times BlockByNumber for every sampled range, cycling
// over the NumBlocks block numbers that begin at the range's first block.
func BenchmarkBlockByNumber(b *testing.B) {
	for label, first := range numbers {
		b.Run(label, func(sb *testing.B) {
			for i := 0; i < sb.N; i++ {
				_ = node.BlockByNumber(first + int64(i)%NumBlocks)
			}
		})
	}
}

// BenchmarkHeaderByNumber times HeaderByNumber for every sampled range,
// cycling over the NumBlocks block numbers that begin at the range's first
// block.
func BenchmarkHeaderByNumber(b *testing.B) {
	for label, first := range numbers {
		b.Run(label, func(sb *testing.B) {
			for i := 0; i < sb.N; i++ {
				_ = node.HeaderByNumber(first + int64(i)%NumBlocks)
			}
		})
	}
}

// BenchmarkBlockByHash times BlockByHash for every sampled range, cycling over
// the block hashes TestMain collected for that range.
func BenchmarkBlockByHash(b *testing.B) {
	for r := range numbers {
		// Look the slice up once so the timed loop measures only the RPC.
		hs := hashes[r]
		b.Run(r, func(b *testing.B) {
			if len(hs) == 0 {
				b.Skipf("no hashes collected for range %q", r)
			}
			size := int64(len(hs))
			for i := int64(0); i < int64(b.N); i++ {
				_ = node.BlockByHash(hs[i%size])
			}
		})
	}
}

// BenchmarkHeaderByHash times HeaderByHash for every sampled range, cycling
// over the block hashes TestMain collected for that range.
func BenchmarkHeaderByHash(b *testing.B) {
	for r := range numbers {
		// Look the slice up once so the timed loop measures only the RPC.
		hs := hashes[r]
		b.Run(r, func(b *testing.B) {
			if len(hs) == 0 {
				b.Skipf("no hashes collected for range %q", r)
			}
			size := int64(len(hs))
			for i := int64(0); i < int64(b.N); i++ {
				_ = node.HeaderByHash(hs[i%size])
			}
		})
	}
}

// getBlocksByNumber fetches block number last, then n more times fetches the
// block whose number is one below the current block's number. It returns the
// number of the block the walk ends on.
func (nd *Node) getBlocksByNumber(last, n int64) int64 {
	cur := nd.BlockByNumber(last)
	for steps := n; steps > 0; steps-- {
		cur = nd.BlockByNumber(cur.Number.Int64() - 1)
	}
	return cur.Number.Int64()
}

// getBlocksByHash fetches block number last, then follows ParentHash links n
// times via BlockByHash. It returns the number of the block the walk ends on.
func (nd *Node) getBlocksByHash(last, n int64) int64 {
	cur := nd.BlockByNumber(last)
	for steps := n; steps > 0; steps-- {
		cur = nd.BlockByHash(cur.ParentHash)
	}
	return cur.Number.Int64()
}

// getHeadersByNumber fetches the header of block last, then n more times
// fetches the header whose number is one below the current header's number.
// It returns the number of the header the walk ends on.
func (nd *Node) getHeadersByNumber(last, n int64) int64 {
	cur := nd.HeaderByNumber(last)
	for steps := n; steps > 0; steps-- {
		cur = nd.HeaderByNumber(cur.Number.Int64() - 1)
	}
	return cur.Number.Int64()
}

// getHeadersByHash fetches the header of block last, then follows ParentHash
// links n times via HeaderByHash. It returns the number of the header the
// walk ends on.
func (nd *Node) getHeadersByHash(last, n int64) int64 {
	cur := nd.HeaderByNumber(last)
	for steps := n; steps > 0; steps-- {
		cur = nd.HeaderByHash(cur.ParentHash)
	}
	return cur.Number.Int64()
}

// TestBlockGetters walks NumBlocks blocks back from the latest block, once by
// block number and once by parent hash, and checks that both walks end on
// block latest-NumBlocks. It is skipped when the chain is too short.
func TestBlockGetters(t *testing.T) {
	latest := node.BlockNumber()
	if latest < NumBlocks {
		t.Skip("No hay suficientes bloques")
	}
	want := latest - NumBlocks
	t.Run("BlockByNumber", func(t *testing.T) {
		if got := node.getBlocksByNumber(latest, NumBlocks); got != want {
			t.Errorf("walking back %d blocks by number from %d ended at %d, want %d",
				NumBlocks, latest, got, want)
		}
	})
	t.Run("BlockByHash", func(t *testing.T) {
		if got := node.getBlocksByHash(latest, NumBlocks); got != want {
			t.Errorf("walking back %d blocks by parent hash from %d ended at %d, want %d",
				NumBlocks, latest, got, want)
		}
	})
}

// TestHeaderGetters walks NumBlocks headers back from the latest block, once
// by block number and once by parent hash, and checks that both walks end on
// block latest-NumBlocks. It is skipped when the chain is too short.
func TestHeaderGetters(t *testing.T) {
	latest := node.BlockNumber()
	if latest < NumBlocks {
		t.Skip("No hay suficientes bloques")
	}
	want := latest - NumBlocks
	t.Run("HeaderByNumber", func(t *testing.T) {
		if got := node.getHeadersByNumber(latest, NumBlocks); got != want {
			t.Errorf("walking back %d headers by number from %d ended at %d, want %d",
				NumBlocks, latest, got, want)
		}
	})
	t.Run("HeaderByHash", func(t *testing.T) {
		if got := node.getHeadersByHash(latest, NumBlocks); got != want {
			t.Errorf("walking back %d headers by parent hash from %d ended at %d, want %d",
				NumBlocks, latest, got, want)
		}
	})
}

// BenchmarkBlockGetters times a complete NumBlocks-long walk back from the
// latest block, first following block numbers and then parent hashes.
func BenchmarkBlockGetters(b *testing.B) {
	latest := node.BlockNumber()
	if latest < NumBlocks {
		b.Skip("No hay suficientes bloques")
	}
	walks := []struct {
		name string
		walk func(last, n int64) int64
	}{
		{"BlockByNumber", node.getBlocksByNumber},
		{"BlockByHash", node.getBlocksByHash},
	}
	for _, w := range walks {
		b.Run(w.name, func(sb *testing.B) {
			for i := 0; i < sb.N; i++ {
				_ = w.walk(latest, NumBlocks)
			}
		})
	}
}

// BenchmarkHeaderGetters times a complete NumBlocks-long header walk back from
// the latest block, first following block numbers and then parent hashes.
func BenchmarkHeaderGetters(b *testing.B) {
	latest := node.BlockNumber()
	if latest < NumBlocks {
		b.Skip("No hay suficientes bloques")
	}
	walks := []struct {
		name string
		walk func(last, n int64) int64
	}{
		{"HeaderByNumber", node.getHeadersByNumber},
		{"HeaderByHash", node.getHeadersByHash},
	}
	for _, w := range walks {
		b.Run(w.name, func(sb *testing.B) {
			for i := 0; i < sb.N; i++ {
				_ = w.walk(latest, NumBlocks)
			}
		})
	}
}

// TestMain dials the node at http://localhost:8545 and requires the latest
// block number to be at least 3*NumBlocks. It then precomputes the benchmark
// inputs for the "lo", "mid" and "hi" ranges: the first block number of each
// range, plus the hashes of the NumBlocks consecutive blocks starting there.
// It exits the process via log.Fatal if dialing fails or the chain is too
// short.
func TestMain(m *testing.M) {
	flag.Parse()
	var err error
	// Plain assignment (not :=) so the package-level node is the one set.
	if node, err = Dial("http://localhost:8545"); err != nil {
		log.Fatal(err)
	}
	latest := node.BlockNumber()
	if latest < 3*NumBlocks {
		log.Fatal("No hay suficientes bloques como para correr benchmarks")
	}
	numbers = map[string]int64{"lo": 1, "mid": latest / 2, "hi": latest - NumBlocks}
	hashes = make(map[string][]common.Hash, len(numbers))
	for r, base := range numbers {
		hs := make([]common.Hash, 0, NumBlocks)
		for i := int64(0); i < NumBlocks; i++ {
			// Fetch block base+i so the hash benchmarks cycle over NumBlocks
			// distinct blocks, matching the number-based benchmarks. The
			// original used base+1, which fetched the same block every time.
			header := node.HeaderByNumber(base + i)
			hs = append(hs, header.GetHash())
		}
		hashes[r] = hs
	}
	os.Exit(m.Run())
}
