// 08-nelder-mead-with-recorder.go

package main
import (
	"fmt"
	"image/color"
	"math"
	"math/rand"

	"gonum.org/v1/gonum/optimize"
	"gonum.org/v1/plot"
	"gonum.org/v1/plot/palette/moreland"
	"gonum.org/v1/plot/plotter"
	"gonum.org/v1/plot/vg"
	"gonum.org/v1/plot/vg/draw"
)
  13. func main() {
  14. points := plotter.XYs{}
  15. for i := 0; i < 10; i++ {
  16. points = append(points, plotter.XY{
  17. X: 100 * rand.Float64(),
  18. Y: 100 * rand.Float64(),
  19. })
  20. }
  21. scatter, err := plotter.NewScatter(points)
  22. if err != nil {
  23. panic(err)
  24. }
  25. scatter.Shape = draw.CircleGlyph{}
  26. heatmap := plotter.NewHeatMap(Heat(points), moreland.SmoothBlueRed().Palette(100))
  27. Func := func(x []float64) float64 {
  28. if len(x) != 2 {
  29. panic("illegal x")
  30. }
  31. var sum float64
  32. for _, point := range points {
  33. sum += math.Sqrt(math.Pow(point.X-x[0], 2) + math.Pow(point.Y-x[1], 2))
  34. }
  35. return sum
  36. }
  37. problem := optimize.Problem{
  38. Func: Func,
  39. }
  40. recorder := &Recorder{}
  41. result, err := optimize.Minimize(problem, []float64{1, 1}, &optimize.Settings{
  42. Recorder: recorder,
  43. }, &optimize.NelderMead{})
  44. if err != nil {
  45. panic(err)
  46. }
  47. pathLines, pathPoints, err := plotter.NewLinePoints(recorder.XYs)
  48. if err != nil {
  49. panic(err)
  50. }
  51. aim, err := plotter.NewScatter(plotter.XYs{{
  52. X: result.X[0],
  53. Y: result.X[1],
  54. }})
  55. if err != nil {
  56. panic(err)
  57. }
  58. aim.Shape = draw.CircleGlyph{}
  59. aim.Color = color.White
  60. plt := plot.New()
  61. plt.Y.Min, plt.X.Min, plt.Y.Max, plt.X.Max = 0, 0, 100, 100
  62. plt.Add(heatmap, scatter, pathPoints, pathLines, aim)
  63. if err := plt.Save(5*vg.Inch, 5*vg.Inch, "08-nelder-mead-with-recorder.png"); err != nil {
  64. panic(err)
  65. }
  66. }
  67. type Heat plotter.XYs
  68. func (h Heat) Dims() (c, r int) { return 100, 100 }
  69. func (h Heat) X(c int) float64 { return float64(c) }
  70. func (h Heat) Y(r int) float64 { return float64(r) }
  71. func (h Heat) Z(c, r int) float64 {
  72. var sum float64
  73. for _, p := range h {
  74. sum += math.Sqrt(math.Pow(p.X-h.X(c), 2) + math.Pow(p.Y-h.Y(r), 2))
  75. }
  76. return -sum
  77. }
  78. type Recorder struct {
  79. XYs plotter.XYs
  80. }
  81. func (r *Recorder) Init() error {
  82. return nil
  83. }
  84. func (r *Recorder) Record(location *optimize.Location, op optimize.Operation, _ *optimize.Stats) error {
  85. if op != optimize.MajorIteration && op != optimize.InitIteration && op != optimize.PostIteration {
  86. return nil
  87. }
  88. r.XYs = append(r.XYs, plotter.XY{
  89. X: location.X[0],
  90. Y: location.X[1],
  91. })
  92. return nil
  93. }