03-least-squares.go 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. package main
  2. import (
  3. "math/rand"
  4. "gonum.org/v1/plot"
  5. "gonum.org/v1/plot/plotter"
  6. "gonum.org/v1/plot/plotutil"
  7. "gonum.org/v1/plot/vg"
  8. )
  9. func main() {
  10. var a, b float64 = 0.7, 3
  11. points1 := plotter.XYs{}
  12. points2 := plotter.XYs{}
  13. for i := 0; i <= 10; i++ {
  14. points1 = append(points1, plotter.XY{
  15. X: float64(i),
  16. Y: a*float64(i) + b,
  17. })
  18. points2 = append(points2, plotter.XY{
  19. X: float64(i),
  20. Y: a*float64(i) + b + (2*rand.Float64() - 1),
  21. })
  22. }
  23. fa, fb := LeastSquares(points2)
  24. points3 := plotter.XYs{}
  25. for i := 0; i <= 10; i++ {
  26. points3 = append(points3, plotter.XY{
  27. X: float64(i),
  28. Y: fa*float64(i) + fb,
  29. })
  30. }
  31. plt := plot.New()
  32. plt.Y.Min, plt.X.Min, plt.Y.Max, plt.X.Max = 0, 0, 10, 10
  33. if err := plotutil.AddLinePoints(plt,
  34. "line1", points1,
  35. "line2", points2,
  36. "line3", points3,
  37. ); err != nil {
  38. panic(err)
  39. }
  40. if err := plt.Save(5*vg.Inch, 5*vg.Inch, "03-least-squares.png"); err != nil {
  41. panic(err)
  42. }
  43. }
  44. func LeastSquares(points plotter.XYs) (a, b float64) {
  45. var xSum, ySum float64
  46. for _, point := range points {
  47. xSum += point.X
  48. ySum += point.Y
  49. }
  50. xAvg, yAvg := xSum/float64(points.Len()), ySum/float64(points.Len())
  51. var xySum, xxSum float64
  52. for _, point := range points {
  53. xySum += (point.X - xAvg) * (point.Y - yAvg)
  54. xxSum += (point.X - xAvg) * (point.X - xAvg)
  55. }
  56. a = xySum / xxSum
  57. b = yAvg - a*xAvg
  58. return
  59. }