【SAS】glmselectプロシジャ_Lasso回帰、ElasticNet回帰

前回 hpgenselect プロシジャで罰則付きロジスティック回帰を実行してみた。
【SAS】hpgenselect プロシジャ_罰則付きロジスティック回帰(LASSO) - こちにぃるの日記

目的変数が連続値の時のLASSO回帰をやってなかったのでまとめてみる。
ついでにElasticNet回帰もやってみる。

【目次】


使用データ


SASHELP にあるベースボールデータを使用する。選手のヒット数、エラー数などの成績から年俸(対数)を予測してみる。

*-- データ;
data ads;
    set sashelp.baseball;
    if cmiss(of n:, of Cr:, logSalary, Division, League) = 0;
    if Division="West"   then Division_West  =1; else Division_West  =0; *One-hot encoding;
    if League="American" then League_American=1; else League_American=0; *One-hot encoding;
run;


データ分割


ベースボールデータの2割をテストデータとする。
ホールドアウト法用に残り8割を 7 : 3 = 訓練データ : バリデーションデータ に細分する。

*-- データ分割;
proc surveyselect data=ads rate=0.2 out=ads outall method=srs seed=19;
run;
data train(drop=selected) test(drop=selected);
    set ads;
    if selected=1 then output test; else output train;
run;
 
proc surveyselect data=train rate=0.3 out=train outall method=srs seed=19;
run;
 
proc format;
    value rollf 0="train-data" 1="validation-data" 2="test-data";
run;
data alldata;
    format selected selected2 rollf.; *-- フォーマット適用;
    set train test(in=in1);
    if in1 then do;
         selected =2; *-- テストデータに番号振る;
         selected2=2;
    end;
    else selected2=0;
run;


LASSO回帰 Hold-out法


ホールドアウト法においては、partition statement で訓練データと検証データを指定できる。
なお、model statement の option に stb を入れると偏回帰係数を出してくれる。

*-- ラッソ回帰;
ods graphics on;
proc glmselect data=train plots=all seed=19;
    partition roleVar=selected(train='0' validate='1');
    model logSalary = nAtBat nHits nHome nRuns nRBI nBB nOuts nAssts nError
                      CrAtBat CrHits CrHome CrRuns CrRbi CrBB
                      Division_West League_American
                      / details=all stats=all
                        selection=lasso(choose=validate stop=none) stb;
    code file="glmselect_lasso.sas";
run;
ods graphics off;
 
data pred_lasso;
    format selected selected2 logSalary P_logSalary;
    set alldata;
    %include glmselect_lasso;
run;

f:id:cochineal19:20210711162118p:plain:w300

LASSO回帰 Cross-validation法


クロスバリデーションでは model statementselection=lasso(choose=cv) cvmethod=random(10) とする。
random(k) で k-fold cross-validation になる。なお、seedも指定可能。

*-- ラッソ回帰 CV;
ods graphics on;
proc glmselect data=train plots=all seed=19;
    model logSalary = nAtBat nHits nHome nRuns nRBI nBB nOuts nAssts nError
                      CrAtBat CrHits CrHome CrRuns CrRbi CrBB
                      Division_West League_American
                      / details=all stats=all
                        selection=lasso(choose=cv stop=none) cvmethod=random(10) stb;
    score out=score_lasso_cv predicted residual;
    code file="glmselect_lasso_cv.sas";
run;
ods graphics off;
 
data pred_lasso_cv;
    format selected2 logSalary P_logSalary;
    set alldata;
    %include glmselect_lasso_cv;
run;

f:id:cochineal19:20210711162331p:plain:w300

ElasticNet回帰 Hold-out法


ElasticNetもLassoと基本は同じ。
L2正則化項の係数を指定でき、グリッドサーチも可能。selection=elasticnet(choose=validate l2search=grid showl2search) とする。
showl2search を入れると最終決定値を表示できる。

もちろん、係数を直接指定することもできる。例:selection=elasticnet(choose=validate l2=0.5)

*-- ElasticNet回帰 ;
ods graphics on;
proc glmselect data=train plots=all seed=19;
    partition roleVar=selected(train='0' validate='1');
    model logSalary = nAtBat nHits nHome nRuns nRBI nBB nOuts nAssts nError
                      CrAtBat CrHits CrHome CrRuns CrRbi CrBB
                      Division_West League_American
                      / details=all stats=all
                        selection=elasticnet(choose=validate l2search=grid showl2search stop=none) stb ;
    score out=score_elasticnet predicted residual;
    code  file="glmselect_elasticnet_l2grid.sas";
run;
ods graphics off;
 
data pred_elasticnet_l2grid;
    format selected selected2 logSalary P_logSalary;
    set alldata;
    %include glmselect_elasticnet_l2grid;
run;

f:id:cochineal19:20210711163023p:plain:w250 f:id:cochineal19:20210711163109p:plain:w300

ElasticNet回帰 Cross-Validation法


クロスバリデーションについても Lassoと同じ。

*-- ElasticNet回帰;
ods graphics on;
proc glmselect data=train plots=all seed=19;
    model logSalary = nAtBat nHits nHome nRuns nRBI nBB nOuts nAssts nError
                      CrAtBat CrHits CrHome CrRuns CrRbi CrBB
                      Division_West League_American
                      / details=all stats=all
                        selection=elasticnet(choose=cv l2search=grid showl2search stop=none) cvmethod=random(10) stb;
    score out=score_elasticnet predicted residual;
    code  file="glmselect_elasticnet_cv_l2grid.sas";
run;
ods graphics off;
 
data pred_elasticnet_cv_l2grid;
    format selected selected2 logSalary P_logSalary;
    set alldata;
    %include glmselect_elasticnet_cv_l2grid;
run;

f:id:cochineal19:20210711163316p:plain:w250 f:id:cochineal19:20210711163352p:plain:w300

今回は L2=0 でLasso回帰と等価になった。

モデル評価


モデル評価指標を複数並べたいので自作する。

*-- 評価指標;
%macro _Scores(inds=, act=logSalary, pred=P_logSalary, group=);
    proc sql;
        create table &inds._Scores as
            select
                 %if &group.^="" %then %do; &group. , %end;
                 count(*) as N
                ,sum(abs(&pred. - &act.)) / count(*) as MAE label="平均絶対誤差(MAE:Mean Absolute Error)"
                ,sum((&pred. - &act.)**2) / count(*) as MSE label="平均二乗誤差(MSE:Mean Squared Error)"
                ,sqrt(calculated MSE)                as RMSE label="MSEの平方根(RMSE:Root Mean Squared Error)"
                ,sum(abs((&pred. - &act.) / &act.)) / count(*) as MAPE label="平均絶対誤差(MAE:Mean Absolute Error)"
                ,sum(((&pred. - &act.) / &act.)**2) / count(*) as MSPE label="平均二乗パーセント誤差(MSPE:Mean Squared Percentage Error)"
                ,sqrt(calculated MSPE) as RMSE as RMSPE label="平均二乗パーセント誤差の平方根(RMSPE:Root Mean Squared Percentage Error)"
            from &inds.
             %if &group.^="" %then %do; group by &group. %end;
        ;
    quit;
    
    title "&inds.";
    proc print data=&inds._Scores; run;
    title "";
%mend _Scores;
%_Scores(inds=pred_lasso,                group=selected);
%_Scores(inds=pred_lasso_cv,             group=selected2);
%_Scores(inds=pred_elasticnet_l2grid,    group=selected);
%_Scores(inds=pred_elasticnet_cv_l2grid, group=selected2);

f:id:cochineal19:20210711163511p:plain:w500

RMSEで見ると今回はクロスバリデーションで作ったモデルが良好な予測性能。

参考


https://support.sas.com/documentation/onlinedoc/stat/131/glmselect.pdf

本ブログは個人メモです。 本ブログの内容によって生じた損害等の一切の責任を負いかねますのでご了承ください。