最近武汉肺炎,牵动着全国人民的心,大家可能都想知道疫情什么时候才能结束,今天我使用2.2之前卫健委发布的确诊人数,按照logistic增长模型拟合了一条曲线并大概预测一下之后的疫情情况,模型的结果仅供参考。
本文基于邢翔瑞博主的文章编写,在此感谢翔瑞作者。原文地址:https://blog.csdn.net/weixin_36474809/article/details/104101055
什么是Logistic增长曲线:
Logistic函数或Logistic曲线是一种常见的S形函数,它是皮埃尔·弗朗索瓦·韦吕勒在1844或1845年在研究它与人口增长的关系时命名的。广义Logistic曲线可以模仿一些情况人口增长(P)的S形曲线。起初阶段大致是指数增长;然后随着开始变得饱和,增加变慢;最后,达到成熟时增加停止。
当一个物种迁入到一个新生态系统中后,其数量会发生变化。假设该物种的起始数量小于环境的最大容纳量,则数量会增长。该物种在此生态系统中有天敌、食物、空间等资源也不足(非理想环境),则增长函数满足逻辑斯谛方程,图像呈S形,此方程是描述在资源有限的条件下种群增长规律的一个最佳数学模型。在以下内容中将具体介绍逻辑斯谛方程的原理、生态学意义及其应用。
Logistic方程,即常微分方程:
而将上面的方程解出来,可以得到logistic函数:
其中为P0初始值,K为终值,r衡量曲线变化快慢,t为时间。
编程实现:
接下来我们就用python来拟合这个曲线。
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
拟合2019-nCov肺炎感染确诊人数
https://blog.csdn.net/z_ccsdn/article/details/104134358
"""
import datetime
import matplotlib.pyplot as plt
import numpy as np
import requests
from matplotlib.font_manager import FontProperties
from scipy.optimize import curve_fit
from sklearn.metrics import mean_squared_error
# 引入中文字体库
font = FontProperties(fname=r"./simsun.ttc", size=14)
sdate = None
hyperparameters_r = None
hyperparameters_K = None
'''
从csv获取数据
csv 样例:
01.11 41 0
01.12 41 0
'''
def load_data_fromcsv(file_path):
data = []
with open(file_path, 'r') as file:
lines = file.readlines()
for line in lines:
line = line.strip()
if line:
line_array = line.split(' ')
if len(line_array) == 3:
data.append(
{'date': line_array[0], 'confirm': line_array[1], 'suspect': line_array[2]})
data.sort(key=lambda x: x["date"])
date_temple = '%Y.%m.%d'
# 获取首次出现感染人数的日期
global sdate
sdate = datetime.datetime.strptime(
'2020.' + data[0]['date'], f'{date_temple}').date()
x_data_history = [datetime.datetime.strptime('2020.' + dd['date'], f'{date_temple}').date().strftime("%m-%d") for dd in
data]
t = [datetime.datetime.strptime(
'2020.' + dd['date'], f'{date_temple}').date() for dd in data]
P_confirm = [int(dd['confirm']) for dd in data]
P_suspect = [int(dd['suspect']) for dd in data]
return np.array(t, dtype=np.datetime64), np.array(P_confirm), np.array(P_suspect), x_data_history
def load_data():
# 拉取腾讯新闻数据
res = requests.get('https://service-n9zsbooc-1252957949.gz.apigw.tencentcs.com/release/qq')
res_json = res.json()
data = res_json['data']['wuwei_ww_cn_day_counts']
# 补充更早些的数据:
data.append({'date': '01.11', 'confirm': '41', 'suspect': '0'})
data.append({'date': '01.12', 'confirm': '41', 'suspect': '0'})
data.sort(key=lambda x: x["date"])
# 因为21号以前并非是全国数据,数据不好要去掉
data = data[10:]
print(data)
# 获取首次出现感染人数的日期
global sdate
sdate = datetime.datetime.strptime('2020.' + data[0]['date'], '%Y.%m/%d').date()
x_data_history = [datetime.datetime.strptime('2020.' + dd['date'], '%Y.%m/%d').date().strftime("%m-%d") for dd in
data]
t = [datetime.datetime.strptime('2020.' + dd['date'], '%Y.%m/%d').date() for dd in data]
P_confirm = [int(dd['confirm']) for dd in data]
P_suspect = [int(dd['suspect']) for dd in data]
return np.array(t, dtype=np.datetime64), np.array(P_confirm), np.array(P_suspect), x_data_history
# 计算相隔天数
def day_delay(t):
t0_date = np.datetime64(sdate, 'D')
t_ = (t - t0_date)
days = (t_ / np.timedelta64(1, 'D')).astype(int)
return days
def logistic_increase_function(t,P0):
r = hyperparameters_r
K = hyperparameters_K
# t:time t0:initial time P0:initial_value K:capacity r:increase_rate
exp_value = np.exp(r * (t))
return (K * exp_value * P0) / (K + (exp_value - 1) * P0)
if __name__ == '__main__':
# 日期及感染人数
t, P_confirm, P_suspect, x_show_data = load_data()
# t, P_confirm, P_suspect, x_show_data = load_data_fromcsv('~/data.csv')
x_data, y_data = day_delay(t), P_confirm
# 分隔训练测试集,将最后的30%数据作为测试集
x_train, x_test, y_train, y_test = x_data[:-1 * int(len(x_data) * 0.3)], x_data[-1 * int(len(x_data) * 0.3):], y_data[:-1 * int(len(x_data) * 0.3)],y_data[-1 * int(len(x_data) * 0.3):]
print(x_train)
print(x_test)
popt = None
mse = float("inf")
r = None
k = None
# 网格搜索来优化r和K参数
max_k = 50000 # 限定的最大感染人数
for k_ in np.arange(20000, max_k, 1):
hyperparameters_K = k_
for r_ in np.arange(0, 1, 0.01):
# 用最小二乘法估计拟合
hyperparameters_r = r_
popt_, pcov_ = curve_fit(logistic_increase_function, x_train, y_train)
# # 获取popt里面是拟合系数
print("K:capacity P0:initial_value r:increase_rate")
print(k_, popt_, r_)
# 计算均方误差对测试集进行验证
mse_ = mean_squared_error(y_test, logistic_increase_function(x_test, *popt_))
print("mse:", mse_)
if mse_ <= mse:
mse = mse_
popt = popt_
r = r_
k = k_
hyperparameters_K = k
hyperparameters_r = r
print("----------------")
print("hyperparameters_K:", hyperparameters_K)
print("hyperparameters_r:", hyperparameters_r)
print("----------------")
popt, pcov = curve_fit(logistic_increase_function, x_data, y_data)
print("K:capacity P0:initial_value r:increase_rate")
print(hyperparameters_K, popt, hyperparameters_r)
# 未来预测
date_nums = 32 #需要预测的总天数,从第一天开始算起
future = np.linspace(0, date_nums, date_nums)
future = np.array(future)
future_predict = logistic_increase_function(future, *popt)
# 绘图
x_show_data_all = [(sdate + (datetime.timedelta(days=fu))).strftime("%m-%d") for fu in future]
plt.scatter(x_show_data, P_confirm, s=35, c='green', marker='.', label="确诊人数")
plt.plot(x_show_data_all, future_predict, 'r-s', marker='+', linewidth=1.5, label='预测曲线')
plt.tick_params(labelsize=5)
plt.xlabel('时间', FontProperties=font)
plt.ylabel('感染人数', FontProperties=font)
plt.xticks(x_show_data_all)
plt.grid() # 显示网格
plt.legend(prop=font) # 指定legend的位置右下角
plt.show()
拟合结果:
2.3日更新拟合结果:
2.4 日更新拟合结果:
2.5 日更新拟合结果:
2.6 日更新拟合结果:
2.7 日更新拟合结果:
2.8 日更新拟合结果:
2.11更新拟合结果:
注意:本文归作者所有,未经作者允许,不得转载