Example¶
Begin with the g4g example, but match this class's style.
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
from scipy import stats
xs = [5,7,8,7,2,17,2,9,4,11,12,9,6]
ys = [99,86,87,88,111,86,103,87,94,78,77,85,86]
m, b, _, _, _ = stats.linregress(xs, ys)
plt.scatter(xs, ys)
plt.plot(xs, [m * x + b for x in xs])
plt.show()
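`stats.linregress` returns the ordinary least-squares slope and intercept, plus an r value, p value, and standard error that we throw away with the underscores. As a sketch of what that line is (not scipy's actual implementation), the slope and intercept have a closed form: m = Σ(x - x̄)(y - ȳ) / Σ(x - x̄)² and b = ȳ - m·x̄. The helper name `least_squares` is hypothetical.

```python
xs = [5, 7, 8, 7, 2, 17, 2, 9, 4, 11, 12, 9, 6]
ys = [99, 86, 87, 88, 111, 86, 103, 87, 94, 78, 77, 85, 86]


def least_squares(xs, ys):
    """Closed-form least-squares slope and intercept (a sketch, not scipy's code)."""
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    # How x and y vary together, divided by how much x varies on its own.
    num = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    den = sum((x - x_mean) ** 2 for x in xs)
    m = num / den
    # The best-fit line always passes through the point of means.
    b = y_mean - m * x_mean
    return m, b


m_cf, b_cf = least_squares(xs, ys)
print(m_cf, b_cf)
```

It should agree with the slope and intercept that `stats.linregress` prints later in these notes.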
Errors¶
To measure error, you computed the sum of squared errors. Well done. You probably did something like the following:
errors = [(m * xs[i] + b - ys[i]) ** 2 for i in range(len(xs))]
print(errors)
[np.float64(21.626948352384407), np.float64(23.49288828029534), np.float64(4.39178485240792), np.float64(8.1051031441658), np.float64(129.88283706421814), np.float64(160.4258038281837), np.float64(11.536994532945045), np.float64(0.11859128985571113), np.float64(4.413400213657534), np.float64(34.12657393735711), np.float64(25.91326891120767), np.float64(5.496074733564338), np.float64(43.53669186049348)]
Calculate Sums¶
We can take the sum of squared errors by just adding these all up! Python has a handy built-in function to do that, or you can write your own.
print(sum(errors))
473.0669610007362
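Written as a formula for $n$ data points $(x_i, y_i)$, the list comprehension plus `sum` computes:

$$\text{SSE}(m, b) = \sum_{i=1}^{n} \left(m x_i + b - y_i\right)^2$$

Squaring makes every term non-negative, so points above the line and points below it cannot cancel each other out.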
Machine Learning¶
Okay, so how did machine learning work? Well, we picked a random $m$ and $b$.
import random
I was fairly surprised that students used random(), which returns a value between zero and one, without adding to it or multiplying it.
After all, we've been working with functions that take values in those ranges and stretch them to be much closer to our data, so we know how to do that.
Pick a random point, find the slope and intercept of the line through it and the previous point, and that is the first guess.
pt = int(random.random() * len(xs))
# I had Gemini do this.
m = (ys[pt] - ys[pt - 1]) / (xs[pt] - xs[pt - 1])
b = ys[pt] - m * xs[pt]
pt, m, b
(1, -6.5, 131.5)
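The pt - 1 trick works here because no two neighboring points share an $x$ value, and when pt is 0, Python's index -1 wraps around to the last point. A slightly more defensive sketch picks two distinct indices with `random.sample` and retries if they would make a vertical line. The helper name `random_line_guess` and the fixed seed are illustrative assumptions.

```python
import random

xs = [5, 7, 8, 7, 2, 17, 2, 9, 4, 11, 12, 9, 6]
ys = [99, 86, 87, 88, 111, 86, 103, 87, 94, 78, 77, 85, 86]


def random_line_guess(xs, ys, rng=random):
    """Pick two distinct points at random and return the line through them (hypothetical helper)."""
    while True:
        i, j = rng.sample(range(len(xs)), 2)  # two different indices
        if xs[i] != xs[j]:  # skip pairs that would make a vertical line
            break
    m = (ys[i] - ys[j]) / (xs[i] - xs[j])
    b = ys[i] - m * xs[i]
    return i, j, m, b


i, j, m, b = random_line_guess(xs, ys, random.Random(0))  # seeded so reruns match
print(i, j, m, b)
```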
Aside¶
Actually, forget machine learning for a moment and think about this: what are the largest and smallest possible $m$ and $b$ values under this model?
# you don't need to know how to write this, but we had a lecture on how you could understand it.
get_m = lambda pt1, pt2 : (ys[pt1] - ys[pt2]) / (xs[pt1] - xs[pt2])
get_mb = lambda pt1, pt2 : (get_m(pt1,pt2), ys[pt1] - get_m(pt1,pt2) * xs[pt1])
mbs = [get_mb(pt1,pt2) for pt2 in range(len(xs)) for pt1 in range(len(xs)) if xs[pt1] != xs[pt2]]
print(max(mbs), min(mbs))
(5.0, 74.0) (-13.0, 164.0)
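One subtlety: `mbs` holds (m, b) tuples, and Python compares tuples element by element, so `max(mbs)` is the pair with the largest slope (ties broken by intercept), not the one with the largest intercept. Passing a `key` function compares only the chosen element. In this sketch the first two pairs are the outputs printed above; the third, (-2.0, 60.0), is made up for illustration.

```python
pairs = [(5.0, 74.0), (-13.0, 164.0), (-2.0, 60.0)]

# Tuples compare by first element, then second, so this picks the largest m.
largest_m = max(pairs)
# key= compares only the b values.
largest_b = max(pairs, key=lambda mb: mb[1])
smallest_b = min(pairs, key=lambda mb: mb[1])
print(largest_m, largest_b, smallest_b)
```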
ms & bs¶
We can see which values $m$ and $b$ could fall between.
ms = [mb[0] for mb in mbs]
bs = [mb[1] for mb in mbs]
print('ms', min(ms), max(ms))
print('bs', min(bs), max(bs))
ms -13.0 5.0
bs 55.4 164.0
Brute force¶
Well, that seems easy enough. Let's simply compute the sum of squared errors for all possible (integer) pairs of $m$ and $b$ values in those ranges and plot it. We have code from last class, from Gemini, that takes these two values and computes the sum of squared errors, or we can write it ourselves.
sse = lambda m, b : sum([(m * xs[i] + b - ys[i]) ** 2 for i in range(len(xs))]) # Sum of Square Error (SSE)
print(sse(5,55.4), sse(-13,164.0))
10181.480000000001 34904.0
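The same sum of squared errors can be written with numpy arrays, so the multiply, subtract, and square happen element-wise instead of in a Python loop. A sketch (`sse_np` is a hypothetical name) that reproduces the two values printed above:

```python
import numpy as np

xs = np.array([5, 7, 8, 7, 2, 17, 2, 9, 4, 11, 12, 9, 6])
ys = np.array([99, 86, 87, 88, 111, 86, 103, 87, 94, 78, 77, 85, 86])


def sse_np(m, b):
    """Sum of squared errors for the line y = m*x + b, using array arithmetic."""
    residuals = m * xs + b - ys  # one residual per data point
    return float(np.sum(residuals ** 2))


print(sse_np(5, 55.4), sse_np(-13, 164.0))
```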
Clarification¶
I want to note something I do here - I'm not making a list! I'm making a list of lists. This list of lists is a LOT like an image. In fact, I'm going to make it into an image now that I've said that...
Keep in mind that sses[0][0] represents the line y = -13x + 55: each row is one b value (starting at 55) and each column is one m value (starting at -13).
sses = [[sse(m, b) for m in range(-13,6)] for b in range(55,165)]
I wonder what the minimum and maximum SSE are. It's a little bit annoying with two dimensions.
print(max([max(i) for i in sses]), min([min(i) for i in sses]))
259553 486
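numpy takes care of the two-dimensional part: `.max()` and `.min()` look at every cell, and `np.unravel_index` turns the flat position of the smallest cell back into a (row, column) pair, which we can translate into $m$ and $b$. A sketch that rebuilds the same grid so it runs on its own; `grid`, `m_values`, and `b_values` are new names.

```python
import numpy as np

xs = [5, 7, 8, 7, 2, 17, 2, 9, 4, 11, 12, 9, 6]
ys = [99, 86, 87, 88, 111, 86, 103, 87, 94, 78, 77, 85, 86]

m_values = list(range(-13, 6))   # one column per m
b_values = list(range(55, 165))  # one row per b
sse = lambda m, b: sum((m * x + b - y) ** 2 for x, y in zip(xs, ys))
grid = np.array([[sse(m, b) for m in m_values] for b in b_values])

print(grid.max(), grid.min())
# argmin gives a position in the flattened grid; unravel_index converts it to (row, col).
row, col = np.unravel_index(grid.argmin(), grid.shape)
best_m, best_b = m_values[col], b_values[row]
print(row, col, best_m, best_b)
```

On this data it should report row 50 and column 11, which is $m = -2$, $b = 105$: the cell behind the minimum of 486 printed above.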
Look at it!¶
Oh, I could look at it as a dataframe or even as an image!
import pandas as pd
import numpy as np
from PIL import Image as im
df = pd.DataFrame(sses)
df.head()
|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 259553 | 229326 | 201025 | 174650 | 150201 | 127678 | 107081 | 88410 | 71665 | 56846 | 43953 | 32986 | 23945 | 16830 | 11641 | 8378 | 7041 | 7630 | 10145 |
| 1 | 256088 | 226059 | 197956 | 171779 | 147528 | 125203 | 104804 | 86331 | 69784 | 55163 | 42468 | 31699 | 22856 | 15939 | 10948 | 7883 | 6744 | 7531 | 10244 |
| 2 | 252649 | 222818 | 194913 | 168934 | 144881 | 122754 | 102553 | 84278 | 67929 | 53506 | 41009 | 30438 | 21793 | 15074 | 10281 | 7414 | 6473 | 7458 | 10369 |
| 3 | 249236 | 219603 | 191896 | 166115 | 142260 | 120331 | 100328 | 82251 | 66100 | 51875 | 39576 | 29203 | 20756 | 14235 | 9640 | 6971 | 6228 | 7411 | 10520 |
| 4 | 245849 | 216414 | 188905 | 163322 | 139665 | 117934 | 98129 | 80250 | 64297 | 50270 | 38169 | 27994 | 19745 | 13422 | 9025 | 6554 | 6009 | 7390 | 10697 |
Images¶
im.fromarray(np.array(sses).astype(np.uint8))
Colors¶
This looks bad because we aren't forcing the error values into the typical range of colors. Converting with astype(np.uint8) keeps only each value modulo 256, so large errors wrap around into noise. Let's fix that.
Colors should be between 0 and 255.
Errors get as large as 259553.
So we divide by 259553 and multiply by 255, more or less: the code integer-divides by 259553 // 255 = 1017.
im.fromarray(np.array([[i//(259553//255) for i in j] for j in sses]).astype(np.uint8))
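The integer division above is roughly "divide by the maximum, multiply by 255". A common alternative is min-max scaling, which also subtracts the smallest value first, so the lowest error maps to exactly 0 and the highest to exactly 255. A sketch with a hypothetical helper `to_gray` and a tiny made-up grid as input; it assumes the grid is not constant (otherwise it would divide by zero).

```python
import numpy as np


def to_gray(values):
    """Linearly rescale a 2D grid so its smallest value maps to 0 and its largest to 255."""
    arr = np.asarray(values, dtype=float)
    lo, hi = arr.min(), arr.max()
    scaled = (arr - lo) / (hi - lo) * 255  # now in [0, 255]
    return scaled.round().astype(np.uint8)


pixels = to_gray([[486, 1000], [100000, 259553]])
print(pixels)
```

With the real grid, `im.fromarray(to_gray(sses))` should give a grayscale picture that uses the full brightness range.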
Square Root¶
Well, look at that - we have high (brighter) errors on the outside and lower (less bright) errors on the inside. I wonder where the minimum is!
Let's make the contrast less sharp by undoing the squaring: we take the square root of each sum, which makes the values much smaller.
259553 ** .5
509.4634432420053
Not bad - the square root alone almost brings the values into the 0 to 255 range. We take the square root and divide by 2, which puts the largest value at about 254. Errors are squared, after all.
im.fromarray(np.array([[(i**.5)//2 for i in j] for j in sses]).astype(np.uint8))
Plotly¶
Now we're getting somewhere, though it's still hard to see. I wonder if plotly can solve this problem for us.
# prompt: plotly 3d surface plot of sses
import plotly.graph_objects as go
import plotly.io as pio # this stuff is just for the website
pio.renderers.default='notebook' # this stuff is just for the website
# Create the surface plot
fig = go.Figure(data=[go.Surface(z=sses)])  # no x= or y= given, so the axes show list positions, not m and b
# Customize the plot
fig.update_layout(title='Sum of Squared Errors',
                  scene=dict(
                      xaxis_title='m',
                      yaxis_title='b',
                      zaxis_title='SSE'
                  ))
# Display the plot
fig
Logs¶
This is just seeing what sticks, but I used a non-linear function, the logarithm, to try to "unsmooth" the gradient so it's easier to see what I should be looking for. The log squashes the huge errors near the edges far more than the small errors near the bottom.
import math
go.Figure(data=[go.Surface(x=list(range(-13,6)),y=list(range(55,165)),z=[[math.log(i) for i in j] for j in sses])])
Images, Again¶
We can see it looks relatively similar (from the top) to the PIL version once we use log for both.
We note the log goes up to about 12.5, so we multiply by 15 to get more visible colors while staying below 255.
im.fromarray(np.array([[math.log(i)*15 for i in j] for j in sses]).astype(np.uint8))
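The scale factors can be checked against the smallest (486) and largest (259553) SSE on the grid. Both multipliers keep every pixel below 255, and 20 uses more of the range:

$$\ln(486) \approx 6.19, \qquad \ln(259553) \approx 12.47$$

$$15 \times [6.19,\ 12.47] \approx [93,\ 187], \qquad 20 \times [6.19,\ 12.47] \approx [124,\ 249]$$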
Line them up¶
We can make this more square-like by repeating each column five times, and more plotly-like by flipping it vertically.
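# sses[-1 - j] walks the rows from last to first (sses[-j] would show row 0 first, since -0 == 0)
im.fromarray(np.array([[math.log(sses[-1 - j][i // 5]) * 20 for i in range(len(sses[0]) * 5)] for j in range(len(sses))]).astype(np.uint8))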
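numpy can do the stretching and flipping without index arithmetic: `np.repeat(..., 5, axis=1)` copies each column five times and `np.flipud` reverses the row order. A sketch that rebuilds the grid so it runs on its own; passing `pixels` to `im.fromarray` should show a picture like the one above.

```python
import numpy as np

xs = [5, 7, 8, 7, 2, 17, 2, 9, 4, 11, 12, 9, 6]
ys = [99, 86, 87, 88, 111, 86, 103, 87, 94, 78, 77, 85, 86]

sse = lambda m, b: sum((m * x + b - y) ** 2 for x, y in zip(xs, ys))
grid = np.array([[sse(m, b) for m in range(-13, 6)] for b in range(55, 165)], dtype=float)

# Log-scale, widen each column 5x, then flip so the last row (b = 164) is on top.
pixels = np.flipud(np.repeat(np.log(grid) * 20, 5, axis=1)).astype(np.uint8)
print(pixels.shape)
```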
Machine Learning¶
Our job as machine learning people is to find where the error is lowest. We do that by picking a point and finding which direction to move to get somewhere lower.
Gradient Descent¶
This is called gradient descent, and it is arguably the most important computational technique in machine learning today. It is much faster than computing every error, as we have done today, especially on meaningfully large data sets, but the brute-force picture is a good way to visualize how it works.
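A minimal sketch of gradient descent on this same SSE, assuming plain full-batch updates starting from $m = 0$, $b = 0$. Differentiating each squared term gives $\partial\,\text{SSE}/\partial m = \sum 2(m x_i + b - y_i)\,x_i$ and $\partial\,\text{SSE}/\partial b = \sum 2(m x_i + b - y_i)$. The learning rate 0.0005 and the 20000 steps are hand-picked for this small data set (hypothetical settings, not from the class); a much larger rate makes the updates overshoot and diverge.

```python
xs = [5, 7, 8, 7, 2, 17, 2, 9, 4, 11, 12, 9, 6]
ys = [99, 86, 87, 88, 111, 86, 103, 87, 94, 78, 77, 85, 86]


def sse_gradient(m, b):
    """Partial derivatives of the sum of squared errors with respect to m and b."""
    dm = sum(2 * (m * x + b - y) * x for x, y in zip(xs, ys))
    db = sum(2 * (m * x + b - y) for x, y in zip(xs, ys))
    return dm, db


def gradient_descent(m=0.0, b=0.0, learning_rate=0.0005, steps=20000):
    """Repeatedly step opposite the gradient (downhill) from a starting guess (m, b)."""
    for _ in range(steps):
        dm, db = sse_gradient(m, b)
        m -= learning_rate * dm
        b -= learning_rate * db
    return m, b


m_gd, b_gd = gradient_descent()
print(m_gd, b_gd)
```

It should land very close to the values `stats.linregress` reports below, without ever evaluating the whole grid.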
Check our work...¶
It looks to me like the minimum was around x=11, y=51 on the plotly surface (those are list positions, not m and b). Let's see...
# Grabbing this from earlier - sses = [[sse(m, b) for m in range(-13,6)] for b in range(55,165)]
x = 11
y = 51
m = x - 13 # we started at -13, and didn't tell plotly. whoops.
b = y + 55
errors = [(m * xs[i] + b - ys[i]) ** 2 for i in range(len(xs))]
print(sum(errors)) # best possible was 473.0669610007362
499
m_stat, b_stat, _, _, _ = stats.linregress(xs, ys)
print(m_stat, b_stat)
print(m,b)
-1.7512877115526118 103.10596026490066
-2 106
For just looking at the low point on a graph, I think that's pretty good. Remember, we didn't even check -1.75 because I only checked integer values. But nothing is stopping you from looking closer, especially once you know you're nearby!
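Taking up that suggestion, a sketch that zooms in with a finer grid: steps of 0.01 in both $m$ and $b$, over a window (-2.5 to -1.0 for $m$, 100 to 110 for $b$) chosen by hand around where the integer grid bottomed out. numpy broadcasting evaluates the whole grid at once instead of a nested list comprehension.

```python
import numpy as np

xs = np.array([5, 7, 8, 7, 2, 17, 2, 9, 4, 11, 12, 9, 6])
ys = np.array([99, 86, 87, 88, 111, 86, 103, 87, 94, 78, 77, 85, 86])

# Step of 0.01 in both directions (window chosen by hand).
m_values = np.linspace(-2.5, -1.0, 151)
b_values = np.linspace(100.0, 110.0, 1001)
M, B = np.meshgrid(m_values, b_values)  # rows follow b, columns follow m, like sses

# Add a last axis for the data points, then sum the squared residuals over it.
residuals = M[..., np.newaxis] * xs + B[..., np.newaxis] - ys
grid = (residuals ** 2).sum(axis=-1)

row, col = np.unravel_index(grid.argmin(), grid.shape)
print(m_values[col], b_values[row], grid[row, col])
```

The best cell should sit within about 0.01 of the `linregress` slope, with an SSE only a hair above the best possible 473.07.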