# Blog: How to use the openturns-mixmod module to create mixture of experts meta-models: expertsMixture.py

File expertsMixture.py, 4.8 KB (added by regis.lebrun@…, 6 years ago)

1 | from openturns import * |

2 | from otmixmod import * |

3 | from time import * |

4 | |

class ExpertMixture(OpenTURNSPythonFunction):
    """Mixture-of-experts meta-model.

    For an input point X, each local expert i produces a prediction Y_i;
    the returned value is the prediction of the expert whose mixture atom
    gives the largest weighted joint PDF at the point (X, Y_i).
    """

    def __init__(self, mixture, metaModels, labels=None):
        # All experts are assumed to share the input/output dimensions of
        # the first one.
        OpenTURNSPythonFunction.__init__(self, metaModels[0].getInputDimension(), metaModels[0].getOutputDimension())
        self.atoms_ = mixture.getDistributionCollection()
        self.localExperts_ = metaModels
        # BUG FIX: the original read the free name 'labels' (a module-level
        # global defined elsewhere in the script), which makes the class
        # unusable on its own. The classification labels are now an explicit,
        # optional constructor argument (unused by _exec, kept for reference).
        self.labels_ = labels

    def _exec(self, X):
        """Return the prediction of the most likely local expert at X."""
        bestLikelihood = -SpecFunc.MaxNumericalScalar
        # BUG FIX: was 'bestY = -1', a misleading sentinel for a sample value;
        # None makes "no expert selected" explicit.
        bestY = None
        size = self.atoms_.getSize()
        for i in range(size):
            Y = self.localExperts_[i](X)
            # Weighted likelihood of the joint point (x, y) under atom i.
            likelihood = self.atoms_[i].getWeight() * self.atoms_[i].computePDF([X[0], Y[0]])
            if likelihood > bestLikelihood:
                bestLikelihood = likelihood
                bestY = Y
        return bestY

25 | |

# Size of the learning sample and dimension of the input space.
size = 2000
dim = 1

# Exact model: two distinct trends glued together at x = 0 via sign(x).
model = NumericalMathFunction("x", "(1.0 + sign(x)) * cos(x) - (sign(x) - 1) * sin(2*x)")
# Learning sample, sorted along x to ease curve plotting.
dataX = Uniform().getNumericalSample(size).sort()
dataY = model(dataX)
# Independent sample used only for validation of the meta-model.
dataXValid = Uniform().getNumericalSample(size)
dataYValid = model(dataXValid)

36 | |

# Gather the learning pairs (x, y) into a single 2-d sample.
# BUG FIX: the original added '0*Normal(0.0, 0.2).getRealization()[0]' to
# every output, i.e. it drew one Gaussian realization per point only to
# multiply it by zero. The noise is now an explicit switch and the wasted
# random draws are skipped when it is disabled (the default, as before).
addNoise = False  # set True to perturb the learning outputs
noiseDistribution = Normal(0.0, 0.2)
data = NumericalSample(size, 2)
for i in range(size):
    data[i, 0] = dataX[i, 0]
    data[i, 1] = dataY[i, 0]
    if addNoise:
        data[i, 1] = data[i, 1] + noiseDistribution.getRealization()[0]

41 | |

# Plot the learning data: exact model as a dashed curve, learning points
# drawn as stars.
graph = Graph("Mixture of experts meta-modeling", "X", "Y", True, "topleft")
modelCurve = Curve(data)
modelCurve.setLineStyle("dashed")
graph.add(modelCurve)
learningCloud = Cloud(data)
learningCloud.setLegendName("Learning")
learningCloud.setPointStyle("star")
graph.add(learningCloud)
graph.draw("DataLearning")

51 | |

# Plot the validation data on top of the learning graph.
# NOTE(review): this appends a second dashed copy of the model curve to the
# same graph object — reproduced as-is to preserve the original output.
validCurve = Curve(data)
validCurve.setLineStyle("dashed")
graph.add(validCurve)
validCloud = Cloud(dataXValid, dataYValid)
validCloud.setColor("red")
validCloud.setLegendName("Validation")
validCloud.setPointStyle("circle")
graph.add(validCloud)
graph.draw("DataValidation")

61 | |

# Model selection: try k = 7 .. kmax clusters and keep the mixture of
# experts with the smallest validation error.
bestExpert = 0
bestCluster = 0
bestError = SpecFunc.MaxNumericalScalar
k = 7
kmax = 50
stop = False
run = 0  # tag used in the exported error-file name
errorData = NumericalSample(0, 2)
covModel = Gaussian_pk_Lk_Ck()
while not stop:
    # NOTE(review): an unused 't0 = time()' timer was removed here.
    # Single-argument print() form: identical output under Python 2 and 3.
    print("Try with %d cluster(s)" % k)
    logLike = NumericalPoint(0)
    labels = Indices(0)
    # Classify the (x, y) data into k clusters
    mixture = MixtureFactory(k, covModel).build(data, labels, logLike)
    # Build the clusters
    clusters = MixtureFactory.BuildClusters(data, labels, k)
    # Build one local polynomial chaos meta-model per cluster
    metaModels = NumericalMathFunctionCollection(k)
    for i in range(k):
        # Input marginal distribution of the current cluster
        distribution = mixture.getDistributionCollection()[i].getMarginal(0)
        # Least-squares projection strategy
        projection = ProjectionStrategy(LeastSquaresStrategy(distribution))
        # Hermite chaos expansion
        basis = OrthogonalProductPolynomialFactory(PolynomialFamilyCollection(dim, OrthogonalUniVariatePolynomialFamily(HermiteFactory())))
        # Fixed truncation strategy: rich basis for a single global expert,
        # small basis for each local expert
        if k == 1:
            degree = 16
        else:
            degree = 2
        adaptive = AdaptiveStrategy(OrthogonalBasis(basis), EnumerateFunction().getStrataCumulatedCardinal(degree))
        algo = FunctionalChaosAlgorithm(clusters[i].getMarginal(0), clusters[i].getMarginal(1), distribution, adaptive, projection)
        algo.run()
        metaModels[i] = algo.getResult().getMetaModel()
    expert = NumericalMathFunction(ExpertMixture(mixture, metaModels))
    # Validation error: sum of squared prediction errors on the validation sample
    dataYMeta = expert(dataXValid)
    error = 0.0
    for i in range(dataYMeta.getSize()):
        error += (dataYValid[i] - dataYMeta[i]).norm2()
    print("k= %s error= %s" % (k, error))
    errorData.add([k, error])
    errorData.exportToCSVFile("Data_error_" + str(run) + ".csv")
    # BUG FIX: the original condition was '0*error < bestError', which is
    # always true, so the "best" expert was simply the last one tried.
    if error < bestError:
        bestExpert = expert
        bestCluster = k
        bestError = error
    # Plot the current meta-model against the exact model
    gk = Graph(graph)
    gk.setDrawables(DrawableCollection(0))
    c = Curve(dataX, expert(dataX))
    c.setLegendName(str(k) + " expert(s)")
    c.setColor("red")
    c.setLineWidth(2)
    gk.add(c)
    c = Curve(data)
    c.setColor("blue")
    c.setLegendName("Model")
    c.setLineStyle("dashed")
    gk.add(c)
    gk.draw("Mixture_" + str(k).rjust(3, "0") + "_experts_1")
    # Same plot with the clusters and the mixture iso-PDF superimposed
    for i in range(k):
        c = Cloud(clusters[i])
        c.setPointStyle("star")
        c.setColor(c.GetValidColors()[i + 2])
        gk.add(c)
    iso = mixture.drawPDF(data.getMin(), data.getMax(), NumericalPoint(2, 201)).getDrawable(0)
    iso.setLegendName("Mixture iso-PDF")
    iso.setColor("green")
    iso.buildDefaultLevels()
    print(iso.getLevels())
    # Put the iso-PDF first so it is drawn below the other drawables
    dbls = DrawableCollection(1, iso)
    for i in range(gk.getDrawables().getSize()):
        dbls.add(gk.getDrawable(i))
    gk.setDrawables(dbls)
    gk.draw("Mixture_" + str(k).rjust(3, "0") + "_experts_2")

    stop = (k == kmax)
    k += 1