09-other-methods.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package main
  2. import (
  3. "fmt"
  4. "image/color"
  5. "math"
  6. "math/rand"
  7. xrand "golang.org/x/exp/rand"
  8. "gonum.org/v1/gonum/mat"
  9. "gonum.org/v1/gonum/optimize"
  10. "gonum.org/v1/gonum/spatial/r1"
  11. "gonum.org/v1/gonum/stat/distmv"
  12. "gonum.org/v1/plot"
  13. "gonum.org/v1/plot/palette/moreland"
  14. "gonum.org/v1/plot/plotter"
  15. "gonum.org/v1/plot/vg"
  16. "gonum.org/v1/plot/vg/draw"
  17. )
  18. func main() {
  19. points := plotter.XYs{}
  20. for i := 0; i < 10; i++ {
  21. points = append(points, plotter.XY{
  22. X: 100 * rand.Float64(),
  23. Y: 100 * rand.Float64(),
  24. })
  25. }
  26. scatter, err := plotter.NewScatter(points)
  27. if err != nil {
  28. panic(err)
  29. }
  30. scatter.Shape = draw.CircleGlyph{}
  31. heatmap := plotter.NewHeatMap(Heat(points), moreland.SmoothBlueRed().Palette(100))
  32. Func := func(x []float64) float64 {
  33. if len(x) != 2 {
  34. panic("illegal x")
  35. }
  36. var sum float64
  37. for _, point := range points {
  38. sum += math.Sqrt(math.Pow(point.X-x[0], 2) + math.Pow(point.Y-x[1], 2))
  39. }
  40. return sum
  41. }
  42. Grad := func(grad, x []float64) {
  43. if len(grad) != len(x) {
  44. panic("illegal grad or x")
  45. }
  46. delta := 1e-9
  47. for i, v := range x {
  48. x[i] = v - delta
  49. f1 := Func(x)
  50. x[i] = v + delta
  51. f2 := Func(x)
  52. x[i] = v
  53. grad[i] = (f2 - f1) / (2 * delta)
  54. }
  55. }
  56. problem := optimize.Problem{
  57. Func: Func,
  58. Grad: Grad,
  59. }
  60. methods := []struct {
  61. Name string
  62. Method optimize.Method
  63. }{
  64. {"BFGS", &optimize.BFGS{}},
  65. {"CG", &optimize.CG{}},
  66. {"CmaEsChol", &optimize.CmaEsChol{}},
  67. {"GradientDescent", &optimize.GradientDescent{}},
  68. {"GuessAndCheck", &optimize.GuessAndCheck{
  69. Rander: distmv.NewUniform([]r1.Interval{{0, 100}, {0, 100}}, xrand.NewSource(0)),
  70. }},
  71. {"LBFGS", &optimize.LBFGS{}},
  72. {"ListSearch", &optimize.ListSearch{
  73. Locs: mat.NewDense(6, 2, []float64{
  74. 0, 10,
  75. 20, 30,
  76. 40, 50,
  77. 60, 70,
  78. 80, 90,
  79. 90, 100,
  80. }),
  81. }},
  82. {"NelderMead", &optimize.NelderMead{}},
  83. //{"Newton", &optimize.Newton{}}, // what a pity
  84. }
  85. for _, method := range methods {
  86. recorder := &Recorder{}
  87. result, err := optimize.Minimize(problem, []float64{1, 1}, &optimize.Settings{
  88. Recorder: recorder,
  89. }, method.Method)
  90. if err != nil {
  91. panic(err)
  92. }
  93. pathLines, pathPoints, err := plotter.NewLinePoints(recorder.XYs)
  94. if err != nil {
  95. panic(err)
  96. }
  97. aim, err := plotter.NewScatter(plotter.XYs{{
  98. X: result.X[0],
  99. Y: result.X[1],
  100. }})
  101. if err != nil {
  102. panic(err)
  103. }
  104. aim.Shape = draw.CircleGlyph{}
  105. aim.Color = color.White
  106. plt := plot.New()
  107. plt.Y.Min, plt.X.Min, plt.Y.Max, plt.X.Max = 0, 0, 100, 100
  108. plt.Title.Text = method.Name
  109. plt.Add(heatmap, scatter, pathPoints, pathLines, aim)
  110. if err := plt.Save(5*vg.Inch, 5*vg.Inch, fmt.Sprintf("09-other-methods.%s.png", method.Name)); err != nil {
  111. panic(err)
  112. }
  113. }
  114. }
  115. type Heat plotter.XYs
  116. func (h Heat) Dims() (c, r int) { return 100, 100 }
  117. func (h Heat) X(c int) float64 { return float64(c) }
  118. func (h Heat) Y(r int) float64 { return float64(r) }
  119. func (h Heat) Z(c, r int) float64 {
  120. var sum float64
  121. for _, p := range h {
  122. sum += math.Sqrt(math.Pow(p.X-h.X(c), 2) + math.Pow(p.Y-h.Y(r), 2))
  123. }
  124. return -sum
  125. }
  126. type Recorder struct {
  127. XYs plotter.XYs
  128. }
  129. func (r *Recorder) Init() error {
  130. return nil
  131. }
  132. func (r *Recorder) Record(location *optimize.Location, op optimize.Operation, _ *optimize.Stats) error {
  133. if op != optimize.MajorIteration && op != optimize.InitIteration && op != optimize.PostIteration {
  134. return nil
  135. }
  136. r.XYs = append(r.XYs, plotter.XY{
  137. X: location.X[0],
  138. Y: location.X[1],
  139. })
  140. return nil
  141. }