solve_schrodinger.py

Retour Ă  la liste des codes python.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/env python
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np

# Constantes physiques
h = 1
h_ = h/(2*np.pi)

def solve(x_min, x_max, N, mm, Ufunc, Uargs):
    dx = (x_max - x_min) / N
    X = np.linspace(x_min, x_max, N)

    U = Ufunc(X, Uargs)

    # Calcul du hamiltonien, fonction énergie(x),
    # sous représentation matricielle
    # Donne une représentation matricielle de H Psi
    V = np.zeros((N, N)).astype(np.complex)
    k = -h_**2 / (2*mm)
    T = np.zeros((N, N)).astype(np.complex)
    def kron(m, n):
        return m==n
    for m in range(N):
        for n in range(N):
            V[m, n] = U[m] * kron(m, n)
            T[m, n] = k/(dx**2) * (kron(m+1, n) - 2*kron(m, n) + kron(m-1, n))

    H = T+V

    # Calcul des valeurs et vecteurs propres (energies/psi),
    # en diagonalisant H Psi pour trouver les E
    # puis calculs de H Psi = E Psi
    E, Y = np.linalg.eig(H)
    # puis tri
    idx = np.real(E).argsort()
    E = E[idx]
    Y = Y.T[idx]

    return E, Y

def plot(x_min, x_max, N, EE, YY, Ufunc, Uargs, ax=None):
    # Préparation de pyplot
    if ax is None:
        fig, ax = plt.subplots()

    X = np.linspace(x_min, x_max, N)
    U = Ufunc(X, Uargs)
    # Plot du potentiel
    ax.plot(X, U)

    # "Normalisation" des psi pour éviter qu'elles se coupent
    E_sep = np.min(np.abs(EE[1:] - EE[0:-1]))
    YY_max = np.max(np.abs(YY))
    r = E_sep / YY_max / 2.
    YY = r * YY

    for n in range(len(EE)):
        # On récupère l'énergie, et psi
        E = EE[n].real * np.ones_like(X)
        Y = YY[n].real
        Y_max = np.max(np.abs(Y))

        # On trace la valeur de E dans U
        mask = np.where(E > U)
        ax.plot(X[mask], E[mask], 'k--')

        # Puis le psi, décalé vers le haut, en masquant les bords éventuels
        mask = np.where(np.abs(Y) > Y_max*.01)
        ax.plot(X[mask],  E[mask] + Y[mask])

    # On définit des marges pour un joli tracé
    x_pad = (x_max - x_min) * 0.02
    x_min = X[mask][0]
    x_max = X[mask][-1]
    ax.set_xlim(x_min - x_pad, x_max + x_pad)

    y_min = min(U)
    y_max = EE[-1] + max(YY[-1])
    y_pad = (y_max - y_min) * 0.02
    ax.set_ylim(y_min - y_pad, y_max + y_pad)

    # Puis les labels sur les axes
    ax.set_xlabel(r'$x$')
    ax.set_ylabel(r'$U(x)$')

    return ax