web-dev-qa-db-fra.com

ajustement courbe_fit multivariée dans python

J'essaie d'adapter une fonction simple à deux tableaux de données indépendantes en python. Je comprends que je dois regrouper les données de mes variables indépendantes dans un seul tableau, mais quelque chose semble toujours mal avec la façon dont je passe les variables lorsque j'essaie de faire l'ajustement. (Il y a quelques messages précédents liés à celui-ci, mais ils n'ont pas beaucoup aidé.)

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def fitFunc(x_3d, a, b, c, d):
    return a + b*x_3d[0,:] + c*x_3d[1,:] + d*x_3d[0,:]*x_3d[1,:]

x_3d = np.array([[1,2,3],[4,5,6]])

p0 = [5.11, 3.9, 5.3, 2]

fitParams, fitCovariances = curve_fit(fitFunc, x_3d[:2,:], x_3d[2,:], p0)
print ' fit coefficients:\n', fitParams

L'erreur que je reçois se lit,

raise TypeError('Improper input: N=%s must not exceed M=%s' % (n, m)) 
TypeError: Improper input: N=4 must not exceed M=3

Quelle est M la longueur de? N la longueur de p0? Qu'est-ce que je fais mal ici?

18
user3133865

N et M sont définis dans l'aide pour la fonction. N est le nombre de points de données et M est le nombre de paramètres. Votre erreur signifie donc que vous avez besoin d'au moins autant de points de données que de paramètres, ce qui est parfaitement logique.

Ce code fonctionne pour moi:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def fitFunc(x, a, b, c, d):
    return a + b*x[0] + c*x[1] + d*x[0]*x[1]

x_3d = np.array([[1,2,3,4,6],[4,5,6,7,8]])

p0 = [5.11, 3.9, 5.3, 2]

fitParams, fitCovariances = curve_fit(fitFunc, x_3d, x_3d[1,:], p0)
print ' fit coefficients:\n', fitParams

J'ai inclus plus de données. J'ai également modifié fitFunc pour qu'il soit écrit sous une forme qui scanne comme n'étant qu'une fonction d'un seul x - l'installateur se chargera d'appeler cela pour tous les points de données. Le code que vous avez publié faisait également référence à x_3d[2,:], ce qui provoquait une erreur.

21
chthonicdaemon