DNN Example
Declare Factory
In [ ]:
from ROOT import TMVA, TFile, TTree, TCut, TString
In [ ]:
# Make sure the TMVA::Tools singleton exists before building TMVA objects.
TMVA.Tools.Instance()
# Remote input file; it is expected to contain the "Sig" and "Bkg" trees.
inputFile = TFile.Open("https://raw.githubusercontent.com/iml-wg/tmvatutorials/master/inputdata.root")
# Depending on the ROOT version, a failed TFile.Open either raises or returns
# a null pointer. Guard against the latter here so the failure is reported
# clearly instead of as a null-pointer error in a later cell.
if not inputFile or inputFile.IsZombie():
    raise RuntimeError("Could not open the TMVA tutorial input file")
# Output file for the factory's results; RECREATE overwrites any old copy.
outputFile = TFile.Open("TMVAOutputDNN.root", "RECREATE")
if not outputFile or outputFile.IsZombie():
    raise RuntimeError("Could not create output file TMVAOutputDNN.root")
# Classification factory writing into outputFile (not verbose, coloured
# output, no progress bar).
factory = TMVA.Factory("TMVAClassification", outputFile,
                       "!V:!Silent:Color:!DrawProgressBar:AnalysisType=Classification")
Declare Variables in the DataLoader
In [ ]:
# Data loader that collects input variables and trees for the factory.
loader = TMVA.DataLoader("dataset_dnn")
# Plain input variables var1..var4, followed by two derived variables
# declared with TMVA's "name := formula" syntax.
for varExpression in ("var1", "var2", "var3", "var4",
                      "var5 := var1-var3",
                      "var6 := var1+var2"):
    loader.AddVariable(varExpression)
Set Up Dataset(s)
In [ ]:
# Signal and background trees from the input file.
tsignal = inputFile.Get("Sig")
tbackground = inputFile.Get("Bkg")
# Get yields a null object when the key is missing; stop here rather than
# passing a null tree to the DataLoader.
if not tsignal or not tbackground:
    raise RuntimeError("Input file is missing the 'Sig' and/or 'Bkg' tree")
loader.AddSignalTree(tsignal)
loader.AddBackgroundTree(tbackground)
# No preselection cut (empty TCut). 1000 signal and 1000 background events
# are drawn at random for training (SplitMode=Random); with nTest_* unset,
# TMVA uses the remaining events for testing.
loader.PrepareTrainingAndTestTree(TCut(""),
    "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V")
Configure Network Layout
In [ ]:
# General layout: each comma-separated ACTIVATION|WIDTH entry is one layer,
# here three TANH layers of width 128 followed by a LINEAR layer.
layoutString = TString("Layout=TANH|128,TANH|128,TANH|128,LINEAR")
# Training strategies: two training phases, later joined with '|'.
# Phase 0 uses LearningRate=1e-1 with non-zero DropConfig entries; phase 1
# uses LearningRate=1e-2 with every DropConfig entry set to 0.0.
# NOTE: key=value pairs are separated by bare commas. Whitespace after a
# comma may end up inside the next key name, so none is used here.
training0 = TString("LearningRate=1e-1,Momentum=0.9,Repetitions=1,"
                    "ConvergenceSteps=2,BatchSize=256,TestRepetitions=10,"
                    "WeightDecay=1e-4,Regularization=L2,"
                    "DropConfig=0.0+0.5+0.5+0.5,Multithreading=True")
training1 = TString("LearningRate=1e-2,Momentum=0.9,Repetitions=1,"
                    "ConvergenceSteps=2,BatchSize=256,TestRepetitions=10,"
                    "WeightDecay=1e-4,Regularization=L2,"
                    "DropConfig=0.0+0.0+0.0+0.0,Multithreading=True")
trainingStrategyString = TString("TrainingStrategy=")
trainingStrategyString += training0 + TString("|") + training1
# General options, then the layout and training strategy, all joined with
# ':' (the separator between top-level method options).
dnnOptions = TString("!H:!V:ErrorStrategy=CROSSENTROPY:VarTransform=N:"
                     "WeightInitialization=XAVIERUNIFORM")
dnnOptions.Append(":")
dnnOptions.Append(layoutString)
dnnOptions.Append(":")
dnnOptions.Append(trainingStrategyString)
Book Methods
In [ ]:
# Book the DNN on the multi-core CPU backend (Architecture=CPU, which relies
# on BLAS). To compare backends, book another kDNN method with a different
# Architecture value and a distinct method title.
stdOptions = dnnOptions + ":Architecture=CPU"
factory.BookMethod(loader, TMVA.Types.kDNN, "DNN", stdOptions)
Train Methods
In [ ]:
# Train all methods booked on the factory.
factory.TrainAllMethods()
Test and Evaluate Methods
In [ ]:
# Run the trained methods on the test sample first, then evaluate the
# performance of every booked method.
factory.TestAllMethods()
factory.EvaluateAllMethods()
Plot ROC Curve
Enable JavaScript (JSROOT) rendering so the plots are drawn in the notebook.
In [ ]:
# ROOT notebook magic: render ROOT graphics with JSROOT in this notebook.
%jsroot on
In [ ]:
# Get a canvas with the ROC curves of the methods trained on `loader` and
# draw it; `c` keeps a reference to the canvas.
c = factory.GetROCCurve(loader)
c.Draw()