Blog: How to use the openturns-mixmod module to create mixture of experts meta-models: expertsMixture.py

File expertsMixture.py, 4.8 KB (added by regis.lebrun@…, 5 years ago)

Script for building a mixture-of-experts meta-model based on Mixmod classification: the learning sample is classified into Gaussian clusters with otmixmod, a local polynomial chaos expansion is fitted on each cluster, and the local experts are combined by the ExpertMixture class.

from openturns import *
from otmixmod import *
from time import *

class ExpertMixture(OpenTURNSPythonFunction):
    def __init__(self, mixture, metaModels):
        OpenTURNSPythonFunction.__init__(self, metaModels[0].getInputDimension(), metaModels[0].getOutputDimension())
        # Weighted atoms of the Mixmod mixture, one per cluster
        self.atoms_ = mixture.getDistributionCollection()
        # One local meta-model (expert) per cluster
        self.localExperts_ = metaModels

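    # For an input point X, evaluate every local expert and keep the prediction
    # maximizing the weighted likelihood w_i * p_i(X, f_i(X)) of the pair
    # (input, local prediction) under atom i of the mixture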
    def _exec(self, X):
        bestLikelihood = -SpecFunc.MaxNumericalScalar
        bestAtom = -1
        bestY = -1
        size = self.atoms_.getSize()
        for i in range(size):
            Y = self.localExperts_[i](X)
            likelihood = self.atoms_[i].getWeight() * self.atoms_[i].computePDF([X[0], Y[0]])
            if likelihood > bestLikelihood:
                bestAtom = i
                bestLikelihood = likelihood
                bestY = Y
        return bestY

size = 2000
dim = 1

model = NumericalMathFunction("x", "(1.0 + sign(x)) * cos(x) - (sign(x) - 1) * sin(2*x)")
dataX = Uniform().getNumericalSample(size)
dataX = dataX.sort()
dataY = model(dataX)
# For validation
dataXValid = Uniform().getNumericalSample(size)
dataYValid = model(dataXValid)

data = NumericalSample(size, 2)
for i in range(size):
    data[i, 0] = dataX[i, 0]
    # Additive Gaussian noise is disabled here by the 0 factor; replace it by 1 to study noisy data
    data[i, 1] = dataY[i, 0] + 0 * Normal(0.0, 0.2).getRealization()[0]

graph = Graph("Mixture of experts meta-modeling", "X", "Y", True, "topleft")
c = Curve(data)
c.setLineStyle("dashed")
graph.add(c)
c = Cloud(data)
c.setLegendName("Learning")
c.setPointStyle("star")
graph.add(c)
graph.draw("DataLearning")

c = Curve(data)
c.setLineStyle("dashed")
graph.add(c)
c = Cloud(dataXValid, dataYValid)
c.setColor("red")
c.setLegendName("Validation")
c.setPointStyle("circle")
graph.add(c)
graph.draw("DataValidation")

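# Main loop: for an increasing number k of clusters, classify the (x, y) learning
# data with Mixmod, fit one local polynomial chaos expert per cluster, combine them
# into an ExpertMixture and measure its error on the validation sample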
bestExpert = 0
bestCluster = 0
bestError = SpecFunc.MaxNumericalScalar
k = 7
kmax = 50
stop = False
run = 0
errorData = NumericalSample(0, 2)
# Mixmod Gaussian covariance model with free proportions and general covariance matrices
covModel = Gaussian_pk_Lk_Ck()
while not stop:
    t0 = time()
    print "Try with", k, "cluster(s)"
    logLike = NumericalPoint(0)
    labels = Indices(0)
    # Classify the (x, y) data
    mixture = MixtureFactory(k, covModel).build(data, labels, logLike)
    # Build the clusters
    clusters = MixtureFactory.BuildClusters(data, labels, k)
    #print "clusters=", clusters
    #print "labels=", labels
    # Build the local meta-models
    metaModels = NumericalMathFunctionCollection(k)
    for i in range(k):
        # Extract the input distribution of the current cluster
        distribution = mixture.getDistributionCollection()[i].getMarginal(0)
        # Build the local meta-model using a polynomial chaos expansion
        # We use a least-squares projection strategy
        projection = ProjectionStrategy(LeastSquaresStrategy(distribution))
        # We use a Hermite chaos expansion
        basis = OrthogonalProductPolynomialFactory(PolynomialFamilyCollection(dim, OrthogonalUniVariatePolynomialFamily(HermiteFactory())))
        # Fixed truncation strategy: high degree for a single cluster, low degree otherwise
        if k == 1:
            degree = 16
        else:
            degree = 2
        adaptive = AdaptiveStrategy(OrthogonalBasis(basis), EnumerateFunction().getStrataCumulatedCardinal(degree))
        algo = FunctionalChaosAlgorithm(clusters[i].getMarginal(0), clusters[i].getMarginal(1), distribution, adaptive, projection)
        algo.run()
        metaModels[i] = algo.getResult().getMetaModel()
    # Combine the local experts into a single meta-model
    expert = NumericalMathFunction(ExpertMixture(mixture, metaModels))
    # Validation error
    dataYMeta = expert(dataXValid)
    error = 0.0
    for i in range(dataYMeta.getSize()):
        error += (dataYValid[i] - dataYMeta[i]).norm2()
    print "k=", k, "error=", error
    errorData.add([k, error])
    errorData.exportToCSVFile("Data_error_" + str(run) + ".csv")
    # Note: the 0 factor makes the test always true, so that the graphs below are
    # produced for every k; remove it to keep only the experts improving the error
    if (0 * error < bestError):
        bestExpert = expert
        bestCluster = k
        bestError = error
        gk = Graph(graph)
        gk.setDrawables(DrawableCollection(0))
        c = Curve(dataX, expert(dataX))
        c.setLegendName(str(k) + " expert(s)")
        c.setColor("red")
        c.setLineWidth(2)
        gk.add(c)
        c = Curve(data)
        c.setColor("blue")
        c.setLegendName("Model")
        c.setLineStyle("dashed")
        gk.add(c)
        gk.draw("Mixture_" + str(k).rjust(3, "0") + "_experts_1")
        # Add the clusters and the mixture iso-PDF to the graph
        for i in range(k):
            c = Cloud(clusters[i])
            c.setPointStyle("star")
            c.setColor(c.GetValidColors()[i + 2])
            gk.add(c)
        iso = mixture.drawPDF(data.getMin(), data.getMax(), NumericalPoint(2, 201)).getDrawable(0)
        iso.setLegendName("Mixture iso-PDF")
        iso.setColor("green")
        iso.buildDefaultLevels()
        print iso.getLevels()
        # Put the iso-PDF contour below the other drawables
        dbls = DrawableCollection(1, iso)
        for i in range(gk.getDrawables().getSize()):
            dbls.add(gk.getDrawable(i))
        gk.setDrawables(dbls)
        gk.draw("Mixture_" + str(k).rjust(3, "0") + "_experts_2")

    stop = (k == kmax)
    k += 1
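
After the loop, the retained meta-model can be evaluated like any other NumericalMathFunction. A minimal usage sketch, assuming the script above has just been run in the same session (so that model, bestExpert and bestCluster are still defined):

# Compare the retained mixture of experts to the exact model on a few new points
xNew = Uniform().getNumericalSample(5)
yTrue = model(xNew)
yHat = bestExpert(xNew)
print "retained number of clusters:", bestCluster
for i in range(xNew.getSize()):
    print "x=", xNew[i, 0], "model=", yTrue[i, 0], "meta-model=", yHat[i, 0]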