10-optimize-fit.go 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. package main
  2. import (
  3. "math"
  4. "math/rand"
  5. "gonum.org/v1/gonum/optimize"
  6. "gonum.org/v1/plot"
  7. "gonum.org/v1/plot/plotter"
  8. "gonum.org/v1/plot/plotutil"
  9. "gonum.org/v1/plot/vg"
  10. )
  11. func main() {
  12. var a, b float64 = 0.7, 3
  13. points1 := plotter.XYs{}
  14. points2 := plotter.XYs{}
  15. for i := 0; i <= 10; i++ {
  16. points1 = append(points1, plotter.XY{
  17. X: float64(i),
  18. Y: a*float64(i) + b,
  19. })
  20. points2 = append(points2, plotter.XY{
  21. X: float64(i),
  22. Y: a*float64(i) + b + (2*rand.Float64() - 1),
  23. })
  24. }
  25. result, err := optimize.Minimize(optimize.Problem{
  26. Func: func(x []float64) float64 {
  27. if len(x) != 2 {
  28. panic("illegal x")
  29. }
  30. a := x[0]
  31. b := x[1]
  32. var sum float64
  33. for _, point := range points2 {
  34. y := a*point.X + b
  35. sum += math.Abs(y - point.Y)
  36. }
  37. return sum
  38. },
  39. }, []float64{1, 1}, &optimize.Settings{}, &optimize.NelderMead{})
  40. if err != nil {
  41. panic(err)
  42. }
  43. fa, fb := result.X[0], result.X[1]
  44. points3 := plotter.XYs{}
  45. for i := 0; i <= 10; i++ {
  46. points3 = append(points3, plotter.XY{
  47. X: float64(i),
  48. Y: fa*float64(i) + fb,
  49. })
  50. }
  51. plt := plot.New()
  52. plt.Y.Min, plt.X.Min, plt.Y.Max, plt.X.Max = 0, 0, 10, 10
  53. if err := plotutil.AddLinePoints(plt,
  54. "line1", points1,
  55. "line2", points2,
  56. "line3", points3,
  57. ); err != nil {
  58. panic(err)
  59. }
  60. if err := plt.Save(5*vg.Inch, 5*vg.Inch, "10-optimize-fit.png"); err != nil {
  61. panic(err)
  62. }
  63. }