-
Notifications
You must be signed in to change notification settings - Fork 59
Cross Validation Added #407
base: master
Are you sure you want to change the base?
Conversation
gramex/handlers/mlhandler.py
Outdated
| from tornado.web import HTTPError | ||
| from sklearn.metrics import get_scorer | ||
| from sklearn.model_selection import cross_val_predict, cross_val_score | ||
| from sklearn.model_selection import cross_val_predict, cross_val_score |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line appears twice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This extra line is unnecessary.
gramex/handlers/mlhandler.py
Outdated
| target = data[target_col] | ||
| train = data[[c for c in data if c != target_col]] | ||
| # cross validation | ||
| mod = cls.modelFunction() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not required. The model is already present as cls.model, see line no: 116.
gramex/handlers/mlhandler.py
Outdated
| # cross validation | ||
| mod = cls.modelFunction() | ||
| CVscore = cross_val_score(mod, train, target) | ||
| CV = sum(CVscore)/len(CVscore) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Use
CVscore.mean() - Variable naming has to follow a specified style - do
pip install flake8and run theflake8command against this file, i.e.flake8 mlhandler.py, and check the output.
|
@prakrutisingh24 In this PR, we are just computing the cross val score when the model is set up for the first time, and simply printing the CV score. What we need is:
Thanks, |
gramex/handlers/mlhandler.py
Outdated
| 'cats': [], | ||
| 'target_col': None, | ||
| 'CV': True, | ||
| 'CVargs': [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's have a single argument, cv, which can take any value, i.e in gramex.yaml, users should be able to write any of the following.
cv: false # disable cross validation
cv: 5 # Use 5 folds
cv:
cv: 8 # Use 8 folds
n_jobs: -1 # with an optional other parameter.
gramex/handlers/mlhandler.py
Outdated
| from tornado.web import HTTPError | ||
| from sklearn.metrics import get_scorer | ||
| from sklearn.model_selection import cross_val_predict, cross_val_score | ||
| from sklearn.model_selection import cross_val_predict, cross_val_score |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This extra line is unnecessary.
gramex/handlers/mlhandler.py
Outdated
| # cross validation | ||
| print('yayyy we are here') | ||
| cls.CrossValidation(train,target) | ||
| print('should have printed') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove the prints.
jaidevd
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The training is happening in def _fit. Cross validation should also happen there.
gramex/handlers/mlhandler.py
Outdated
| from tornado.web import HTTPError | ||
| from sklearn.metrics import get_scorer | ||
| from sklearn.model_selection import cross_val_predict, cross_val_score | ||
| from ast import literal_eval |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should not be required.
gramex/handlers/mlhandler.py
Outdated
| 'nums': [], | ||
| 'cats': [], | ||
| 'target_col': None, | ||
| 'CV': True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make it lowercase.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have to support three cases for the cv option:
- If the user sets
cv: false- then no cross validation happens - If the user sets
cv: 4(or some other integer) pass it straight tocross_val_score - The default should be
cv: None, and in this case, the user should not have to write anything in gramex.yaml
gramex/handlers/mlhandler.py
Outdated
| target = data[target_col] | ||
| train = data[[c for c in data if c != target_col]] | ||
| # cross validation | ||
| cls.CrossValidation(train,target) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make it lowercase.
gramex/handlers/mlhandler.py
Outdated
| mclass = model_kwargs.get('class', False) | ||
| if mclass: | ||
| model = search_modelclass(mclass)(**model_kwargs.get('params', {})) | ||
| return model |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is not required.
| if CV: | ||
| CVscore = cross_val_score(mod, X=train, y=target, **literal_eval(json.dumps(CV))) | ||
| CVavg = sum(CVscore)/len(CVscore) | ||
| print('Cross Validation Score : ',CVavg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CV should take place within the train method only.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if cv:
cvscore = cross_val_score(mod, X=train, y=target, cv=cv)
else:
# Do the usual .fit| target = data[target_col] | ||
| train = data[[c for c in data if c != target_col]] | ||
| # cross validation | ||
| cls.cross_validation(train,target) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not required here.
No description provided.