from fastai2.tabular.all import *
# Load the Titanic dataset straight from the seaborn-data repository.
titanic = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/titanic.csv')
Currently using fastai2.
Installation via `pip install fastai2`.
Oversampling
# Column roles for the tabular model.
# NOTE(review): 'survived' is used as a feature while the target is 'alive';
# these look like the same information -- confirm this is intended, since it
# would explain the 100% validation accuracy reported further down.
cat_vars = ['survived', 'pclass', 'sex', 'embarked', 'class', 'who', 'adult_male', 'deck', 'embark_town', 'alone']
cont_vars = ['age', 'sibsp', 'parch', 'fare']
dep_var = 'alive'

# Preprocessing steps, target block, and a random train/validation split.
procs = [Categorify, FillMissing, Normalize]
block_y = CategoryBlock()
splits = RandomSplitter()(range_of(titanic))

to = TabularPandas(titanic, procs=procs, cat_names=cat_vars, cont_names=cont_vars,
                   y_names=dep_var, block_y=block_y, splits=splits)
# Display the DataFrame held by the TabularPandas object (output shown below).
to.items
| | survived | pclass | sex | age | sibsp | parch | fare | embarked | class | who | adult_male | deck | embark_town | alive | alone | age_na | sibsp_na | parch_na | fare_na |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
143 | 1 | 3 | 2 | -0.799442 | -0.464437 | -0.467832 | -0.509204 | 2 | 3 | 2 | 2 | 0 | 2 | 0 | 2 | 1 | 1 | 1 | 1 |
376 | 2 | 3 | 1 | -0.568123 | -0.464437 | -0.467832 | -0.498856 | 3 | 3 | 3 | 1 | 0 | 3 | 1 | 2 | 1 | 1 | 1 | 1 |
853 | 2 | 1 | 1 | -1.030761 | -0.464437 | 0.815107 | 0.166542 | 3 | 1 | 3 | 1 | 4 | 3 | 1 | 1 | 1 | 1 | 1 | 1 |
333 | 1 | 3 | 2 | -1.030761 | 1.242489 | -0.467832 | -0.276367 | 3 | 3 | 2 | 2 | 0 | 3 | 0 | 1 | 1 | 1 | 1 | 1 |
791 | 1 | 2 | 2 | -1.030761 | -0.464437 | -0.467832 | -0.110793 | 3 | 2 | 2 | 2 | 0 | 3 | 0 | 2 | 1 | 1 | 1 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
206 | 1 | 3 | 2 | 0.202940 | 0.389026 | -0.467832 | -0.320865 | 3 | 3 | 2 | 2 | 0 | 3 | 0 | 1 | 1 | 1 | 1 | 1 |
654 | 1 | 3 | 1 | -0.876548 | -0.464437 | -0.467832 | -0.509204 | 2 | 3 | 3 | 1 | 0 | 2 | 0 | 2 | 1 | 1 | 1 | 1 |
802 | 2 | 1 | 2 | -1.416293 | 0.389026 | 2.098045 | 1.834694 | 3 | 1 | 1 | 1 | 2 | 3 | 1 | 1 | 1 | 1 | 1 | 1 |
257 | 2 | 1 | 1 | 0.048727 | -0.464437 | -0.467832 | 1.141355 | 3 | 1 | 3 | 1 | 2 | 3 | 1 | 2 | 1 | 1 | 1 | 1 |
41 | 1 | 2 | 1 | -0.182592 | 0.389026 | -0.467832 | -0.214277 | 3 | 2 | 3 | 1 | 0 | 3 | 0 | 1 | 1 | 1 | 1 | 1 |
891 rows × 19 columns
# Plot the class balance of the target; the trailing ';' suppresses the
# notebook's textual echo of the plot object.
to.items.alive.value_counts().plot.bar();

# Build the DataLoaders and a tabular model with two hidden layers (200, 100).
dls = to.dataloaders()
learn = tabular_learner(dls, [200,100], metrics=accuracy)
The `get_idxs` method of `DataLoaders` samples the idxs to be used in an epoch of training:
# Peek at the first ten indices sampled for an epoch.
learn.dls.get_idxs()[:10]
[412, 122, 391, 38, 310, 400, 316, 320, 57, 500]
# Count how many indices get_idxs yields for one epoch.
len(learn.dls.get_idxs())
700
# Per-class weight = inverse class frequency, looked up for every row.
df = pd.DataFrame(1 / learn.dls.items.alive.value_counts())
weights = learn.dls.items.alive.apply(lambda x: df.to_dict()['alive'][x])
# Sampling with these weights (with replacement) should give roughly balanced classes.
learn.dls.train.items.sample(n=700, weights=weights, replace=True).alive.value_counts()
0 352
1 348
Name: alive, dtype: int64
def sample_epoch(self, n=700):
    """Draw one epoch of class-balanced (oversampled) row positions.

    Written to be bound onto a training DataLoader in place of its
    ``get_idxs`` method. Each row of ``self.items`` is weighted by the inverse
    frequency of its ``alive`` class, so sampling with replacement draws every
    class about equally often.

    Args:
        n: Number of indices to draw for the epoch; also stored on ``self.n``.

    Returns:
        A list of ``n`` integer positions into ``self.items``.
    """
    # Inverse class frequency, indexed by class value; mapping it over the
    # column yields one sampling weight per row without rebuilding a dict for
    # every row.
    class_weights = 1 / self.items.alive.value_counts()
    weights = self.items.alive.map(class_weights)
    self.n = n  # going for bigger epochs as this will require quite a bit of training
    oversampled_idxs = self.items.sample(n=self.n, weights=weights, replace=True).index
    # Translate the sampled index labels into positions with a single pass
    # over the index; setdefault keeps the first position if a label repeats.
    position_of = {}
    for position, label in enumerate(self.items.index):
        position_of.setdefault(label, position)
    return [position_of[label] for label in oversampled_idxs]
# Replace the training DataLoader's index sampler with the oversampling version.
learn.dls.train.get_idxs = types.MethodType(sample_epoch, learn.dls.train)
idxs = learn.dls.train.get_idxs()
# Check the class balance of the rows selected for one epoch.
learn.dls.train.items.iloc[idxs].alive.value_counts()
1 360
0 340
Name: alive, dtype: int64
# Train for 10 epochs at learning rate 1e-3 with the oversampled epochs.
learn.fit(10, 1e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.401528 | 0.673750 | 0.657303 | 00:00 |
1 | 0.231258 | 0.608156 | 0.792135 | 00:00 |
2 | 0.144609 | 0.453062 | 0.955056 | 00:00 |
3 | 0.098268 | 0.222973 | 1.000000 | 00:00 |
4 | 0.070314 | 0.066305 | 1.000000 | 00:00 |
5 | 0.052205 | 0.014175 | 1.000000 | 00:00 |
6 | 0.039717 | 0.002960 | 1.000000 | 00:00 |
7 | 0.030711 | 0.000851 | 1.000000 | 00:00 |
8 | 0.024077 | 0.000408 | 1.000000 | 00:00 |
9 | 0.019053 | 0.000268 | 1.000000 | 00:00 |
# Re-create the learner from the same DataLoaders and train again with the
# same settings.
learn = tabular_learner(dls, [200,100], metrics=accuracy)
learn.fit(10, 1e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.328865 | 0.669253 | 0.775281 | 00:00 |
1 | 0.180982 | 0.614954 | 0.943820 | 00:00 |
2 | 0.110680 | 0.484824 | 0.977528 | 00:00 |
3 | 0.074929 | 0.271870 | 1.000000 | 00:00 |
4 | 0.053904 | 0.095419 | 1.000000 | 00:00 |
5 | 0.040015 | 0.020298 | 1.000000 | 00:00 |
6 | 0.030432 | 0.003644 | 1.000000 | 00:00 |
7 | 0.023582 | 0.000864 | 1.000000 | 00:00 |
8 | 0.018468 | 0.000273 | 1.000000 | 00:00 |
9 | 0.014615 | 0.000178 | 1.000000 | 00:00 |