”工欲善其事,必先利其器。“—孔子《论语.录灵公》
首页 > 编程 > 为什么现实世界的机器学习需要分布式计算

为什么现实世界的机器学习需要分布式计算

发布于2024-11-08
浏览:175

Why You Need Distributed Computing for Real-World Machine Learning

PySpark 如何帮助您像专业人士一样处理庞大的数据集

PyTorch 和 TensorFlow 等机器学习框架非常适合构建模型。但现实是,当涉及到现实世界的项目时(处理巨大的数据集),您需要的不仅仅是一个好的模型。您需要一种有效处理和管理所有数据的方法。这就是像 PySpark 这样的分布式计算可以拯救世界的地方。

让我们来分析一下为什么在现实世界的机器学习中处理大数据意味着超越 PyTorch 和 TensorFlow,以及 PySpark 如何帮助您实现这一目标。
真正的问题:大数据
您在网上看到的大多数机器学习示例都使用小型、易于管理的数据集。您可以将整个事情放入内存中,进行尝试,并在几分钟内训练一个模型。但在现实场景中(例如信用卡欺诈检测、推荐系统或财务预测),您要处理数百万甚至数十亿行。突然间,您的笔记本电脑或服务器无法处理它。

如果您尝试将所有数据一次性加载到 PyTorch 或 TensorFlow 中,事情将会崩溃。这些框架是为模型训练而设计的,而不是为了有效处理巨大的数据集。这就是分布式计算变得至关重要的地方。
为什么 PyTorch 和 TensorFlow 还不够
PyTorch 和 TensorFlow 非常适合构建和优化模型,但在处理大规模数据任务时却表现不佳。两大问题:

  • 内存过载:他们在训练之前将整个数据集加载到内存中。这适用于小型数据集,但当您拥有 TB 级的数据时,游戏就结束了。
  • 无分布式数据处理:PyTorch 和 TensorFlow 不是为处理分布式数据处理而构建的。如果您有大量数据分布在多台机器上,那么它们并没有真正的帮助。

这就是 PySpark 的闪光点。它旨在处理分布式数据,在多台机器上高效处理数据,同时处理大量数据集,而不会导致系统崩溃。

真实示例:使用 PySpark 检测信用卡欺诈
让我们深入研究一个例子。假设您正在开发使用信用卡交易数据的欺诈检测系统。在本例中,我们将使用 Kaggle 的流行数据集。它包含超过 284,000 笔交易,其中不到 1% 是欺诈交易。

第 1 步:在 Google Colab 中设置 PySpark
为此,我们将使用 Google Colab,因为它允许我们以最少的设置运行 PySpark。

!pip install pyspark

接下来,导入必要的库并启动 Spark 会话。

import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, udf
from pyspark.ml.feature import VectorAssembler, StringIndexer, MinMaxScaler
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.linalg import Vectors
import numpy as np
from pyspark.sql.types import FloatType

启动 pyspark 会话

spark = SparkSession.builder \
    .appName("FraudDetectionImproved") \
    .master("local[*]") \
    .config("spark.executorEnv.PYTHONHASHSEED", "0") \
    .getOrCreate()

第 2 步:加载和准备数据

data = spark.read.csv('creditcard.csv', header=True, inferSchema=True)
data = data.orderBy("Time")  # Ensure data is sorted by time
data.show(5)
data.describe().show()
# Check for missing values in each column
data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns]).show()

# Prepare the feature columns
feature_columns = data.columns
feature_columns.remove("Class")  # Removing "Class" column as it is our label

# Assemble features into a single vector
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
data = assembler.transform(data)
data.select("features", "Class").show(5)

# Split data into train (60%), test (20%), and unseen (20%)
train_data, temp_data = data.randomSplit([0.6, 0.4], seed=42)
test_data, unseen_data = temp_data.randomSplit([0.5, 0.5], seed=42)

# Print class distribution in each dataset
print("Train Data:")
train_data.groupBy("Class").count().show()

print("Test and parameter optimisation Data:")
test_data.groupBy("Class").count().show()

print("Unseen Data:")
unseen_data.groupBy("Class").count().show()

第3步:初始化模型

# Initialize RandomForestClassifier
rf = RandomForestClassifier(labelCol="Class", featuresCol="features", probabilityCol="probability")

# Create ParamGrid for Cross Validation
paramGrid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [10, 20 ]) \
    .addGrid(rf.maxDepth, [5, 10]) \
    .build()

# Create 5-fold CrossValidator
crossval = CrossValidator(estimator=rf,
                          estimatorParamMaps=paramGrid,
                          evaluator=BinaryClassificationEvaluator(labelCol="Class", metricName="areaUnderROC"),
                          numFolds=5)

第 4 步:拟合、运行交叉验证,并选择最佳参数集

# Run cross-validation, and choose the best set of parameters
rf_model = crossval.fit(train_data)

# Make predictions on test data
predictions_rf = rf_model.transform(test_data)

# Evaluate Random Forest Model
binary_evaluator = BinaryClassificationEvaluator(labelCol="Class", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
pr_evaluator = BinaryClassificationEvaluator(labelCol="Class", rawPredictionCol="rawPrediction", metricName="areaUnderPR")

auc_rf = binary_evaluator.evaluate(predictions_rf)
auprc_rf = pr_evaluator.evaluate(predictions_rf)
print(f"Random Forest - AUC: {auc_rf:.4f}, AUPRC: {auprc_rf:.4f}")

# UDF to extract positive probability from probability vector
extract_prob = udf(lambda prob: float(prob[1]), FloatType())
predictions_rf = predictions_rf.withColumn("positive_probability", extract_prob(col("probability")))

Step 5 计算精确率、召回率和 F1 分数的函数

# Function to calculate precision, recall, and F1-score
def calculate_metrics(predictions):
    tp = predictions.filter((col("Class") == 1) & (col("prediction") == 1)).count()
    fp = predictions.filter((col("Class") == 0) & (col("prediction") == 1)).count()
    fn = predictions.filter((col("Class") == 1) & (col("prediction") == 0)).count()

    precision = tp / (tp   fp) if (tp   fp) != 0 else 0
    recall = tp / (tp   fn) if (tp   fn) != 0 else 0
    f1_score = (2 * precision * recall) / (precision   recall) if (precision   recall) != 0 else 0

    return precision, recall, f1_score

第 6 步:找到模型的最佳阈值

# Find the best threshold for the model
best_threshold = 0.5
best_f1 = 0
for threshold in np.arange(0.1, 0.9, 0.1):
    thresholded_predictions = predictions_rf.withColumn("prediction", (col("positive_probability") > threshold).cast("double"))
    precision, recall, f1 = calculate_metrics(thresholded_predictions)

    if f1 > best_f1:
        best_f1 = f1
        best_threshold = threshold

print(f"Best threshold: {best_threshold}, Best F1-score: {best_f1:.4f}")

第七步:评估未见过的数据

# Evaluate on unseen data
predictions_unseen = rf_model.transform(unseen_data)
auc_unseen = binary_evaluator.evaluate(predictions_unseen)
print(f"Unseen Data - AUC: {auc_unseen:.4f}")

precision, recall, f1 = calculate_metrics(predictions_unseen)
print(f"Unseen Data - Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")

area_under_roc = binary_evaluator.evaluate(predictions_unseen)
area_under_pr = pr_evaluator.evaluate(predictions_unseen)
print(f"Unseen Data - AUC: {area_under_roc:.4f}, AUPRC: {area_under_pr:.4f}")

结果

Best threshold: 0.30000000000000004, Best F1-score: 0.9062
Unseen Data - AUC: 0.9384
Unseen Data - Precision: 0.9655, Recall: 0.7568, F1-score: 0.8485
Unseen Data - AUC: 0.9423, AUPRC: 0.8618

然后您可以保存此模型(几KB)并在 pyspark 管道中的任何地方使用它

rf_model.save()

这就是 PySpark 在处理现实机器学习任务中的大型数据集时产生巨大差异的原因:

轻松扩展:PySpark 可以跨集群分配任务,使您能够在不耗尽内存的情况下处理 TB 级的数据。
动态数据处理:PySpark 不需要将整个数据集加载到内存中。它根据需要处理数据,这使得它更加高效。
更快的模型训练:借助分布式计算,您可以通过在多台机器上分配计算工作负载来更快地训练模型。
最后的想法
PyTorch 和 TensorFlow 是构建机器学习模型的绝佳工具,但对于现实世界的大规模任务,您需要更多工具。使用 PySpark 进行分布式计算可让您高效处理庞大数据集、实时处理数据并扩展机器学习管道。

因此,下次您处理海量数据时(无论是欺诈检测、推荐系统还是财务分析),请考虑使用 PySpark 将您的项目提升到一个新的水平。

有关完整的代码和结果,请查看此笔记本。 :
https://colab.research.google.com/drive/1W9naxNZirirLRodSEnHAUWevYd5LH8D4?authuser=5#scrollTo=odmodmqKcY23

__

我是 Swapnil,请随意留下您的评论、结果和想法,或者联系我 - [email protected] 获取数据、软件开发工作和工作信息

版本声明 本文转载于:https://dev.to/femtyfem/why-you-need-distributed-computing-for-real-world-machine-learning-17oo?1如有侵犯,请联系[email protected]删除
最新教程 更多>
  • 在Python中如何创建动态变量?
    在Python中如何创建动态变量?
    在Python 中,动态创建变量的功能可以是一种强大的工具,尤其是在使用复杂的数据结构或算法时,Dynamic Variable Creation的动态变量创建。 Python提供了几种创造性的方法来实现这一目标。利用dictionaries 一种有效的方法是利用字典。字典允许您动态创建密钥并分...
    编程 发布于2025-07-02
  • 为什么我会收到MySQL错误#1089:错误的前缀密钥?
    为什么我会收到MySQL错误#1089:错误的前缀密钥?
    mySQL错误#1089:错误的前缀键错误descript [#1089-不正确的前缀键在尝试在表中创建一个prefix键时会出现。前缀键旨在索引字符串列的特定前缀长度长度,可以更快地搜索这些前缀。了解prefix keys `这将在整个Movie_ID列上创建标准主键。主密钥对于唯一识别...
    编程 发布于2025-07-02
  • 如何将多种用户类型(学生,老师和管理员)重定向到Firebase应用中的各自活动?
    如何将多种用户类型(学生,老师和管理员)重定向到Firebase应用中的各自活动?
    Red: How to Redirect Multiple User Types to Respective ActivitiesUnderstanding the ProblemIn a Firebase-based voting app with three distinct user type...
    编程 发布于2025-07-02
  • C++成员函数指针正确传递方法
    C++成员函数指针正确传递方法
    如何将成员函数置于c 的函数时,接受成员函数指针的函数时,必须同时提供对象的指针,并提供指针和指针到函数。需要具有一定签名的功能指针。要通过成员函数,您需要同时提供对象指针(此)和成员函数指针。这可以通过修改Menubutton :: SetButton()(如下所示:[&& && && &&华)...
    编程 发布于2025-07-02
  • 用户本地时间格式及时区偏移显示指南
    用户本地时间格式及时区偏移显示指南
    在用户的语言环境格式中显示日期/时间,并使用时间偏移在向最终用户展示日期和时间时,以其localzone and格式显示它们至关重要。这确保了不同地理位置的清晰度和无缝用户体验。以下是使用JavaScript实现此目的的方法。方法:推荐方法是处理客户端的Javascript中的日期/时间格式化和时...
    编程 发布于2025-07-02
  • 可以在纯CS中将多个粘性元素彼此堆叠在一起吗?
    可以在纯CS中将多个粘性元素彼此堆叠在一起吗?
    [2这里: https://webthemez.com/demo/sticky-multi-header-scroll/index.html </main> <section> { display:grid; grid-template-...
    编程 发布于2025-07-02
  • Java为何无法创建泛型数组?
    Java为何无法创建泛型数组?
    通用阵列创建错误 arrayList [2]; JAVA报告了“通用数组创建”错误。为什么不允许这样做?答案:Create an Auxiliary Class:public static ArrayList<myObject>[] a = new ArrayList<myO...
    编程 发布于2025-07-02
  • 如何使用node-mysql在单个查询中执行多个SQL语句?
    如何使用node-mysql在单个查询中执行多个SQL语句?
    在node-mysql node-mysql文档最初出于安全原因最初禁用多个语句支持,因为它可能导致SQL注入攻击。要启用此功能,您需要在创建连接时将倍增设置设置为true: var connection = mysql.createconnection({{multipleStatement:...
    编程 发布于2025-07-02
  • 在Java中使用for-to-loop和迭代器进行收集遍历之间是否存在性能差异?
    在Java中使用for-to-loop和迭代器进行收集遍历之间是否存在性能差异?
    For Each Loop vs. Iterator: Efficiency in Collection TraversalIntroductionWhen traversing a collection in Java, the choice arises between using a for-...
    编程 发布于2025-07-02
  • Spark DataFrame添加常量列的妙招
    Spark DataFrame添加常量列的妙招
    在Spark Dataframe ,将常数列添加到Spark DataFrame,该列具有适用于所有行的任意值的Spark DataFrame,可以通过多种方式实现。使用文字值(SPARK 1.3)在尝试提供直接值时,用于此问题时,旨在为此目的的column方法可能会导致错误。 df.withCo...
    编程 发布于2025-07-02
  • Go web应用何时关闭数据库连接?
    Go web应用何时关闭数据库连接?
    在GO Web Applications中管理数据库连接很少,考虑以下简化的web应用程序代码:出现的问题:何时应在DB连接上调用Close()方法?,该特定方案将自动关闭程序时,该程序将在EXITS EXITS EXITS出现时自动关闭。但是,其他考虑因素可能保证手动处理。选项1:隐式关闭终止数...
    编程 发布于2025-07-02
  • 为什么使用Firefox后退按钮时JavaScript执行停止?
    为什么使用Firefox后退按钮时JavaScript执行停止?
    导航历史记录问题:JavaScript使用Firefox Back Back 此行为是由浏览器缓存JavaScript资源引起的。要解决此问题并确保在后续页面访问中执行脚本,Firefox用户应设置一个空功能。 警报'); }; alert('inline Alert')...
    编程 发布于2025-07-02
  • Java是否允许多种返回类型:仔细研究通用方法?
    Java是否允许多种返回类型:仔细研究通用方法?
    在Java中的多个返回类型:一种误解类型:在Java编程中揭示,在Java编程中,Peculiar方法签名可能会出现,可能会出现,使开发人员陷入困境,使开发人员陷入困境。 getResult(string s); ,其中foo是自定义类。该方法声明似乎拥有两种返回类型:列表和E。但这确实是如此吗...
    编程 发布于2025-07-02
  • 如何在Java的全屏独家模式下处理用户输入?
    如何在Java的全屏独家模式下处理用户输入?
    Handling User Input in Full Screen Exclusive Mode in JavaIntroductionWhen running a Java application in full screen exclusive mode, the usual event ha...
    编程 发布于2025-07-02
  • 为什么HTML无法打印页码及解决方案
    为什么HTML无法打印页码及解决方案
    无法在html页面上打印页码? @page规则在@Media内部和外部都无济于事。 HTML:Customization:@page { margin: 10%; @top-center { font-family: sans-serif; font-weight: bo...
    编程 发布于2025-07-02

免责声明: 提供的所有资源部分来自互联网,如果有侵犯您的版权或其他权益,请说明详细缘由并提供版权或权益证明然后发到邮箱:[email protected] 我们会第一时间内为您处理。

Copyright© 2022 湘ICP备2022001581号-3