”工欲善其事,必先利其器。“—孔子《论语.录灵公》
首页 > 编程 > 为什么现实世界的机器学习需要分布式计算

为什么现实世界的机器学习需要分布式计算

发布于2024-09-16
浏览:718

Why You Need Distributed Computing for Real-World Machine Learning

PySpark 如何帮助您像专业人士一样处理庞大的数据集

PyTorch 和 TensorFlow 等机器学习框架非常适合构建模型。但现实是,当涉及到现实世界的项目时(处理巨大的数据集),您需要的不仅仅是一个好的模型。您需要一种有效处理和管理所有数据的方法。这就是像 PySpark 这样的分布式计算可以拯救世界的地方。

让我们来分析一下为什么在现实世界的机器学习中处理大数据意味着超越 PyTorch 和 TensorFlow,以及 PySpark 如何帮助您实现这一目标。
真正的问题:大数据
您在网上看到的大多数机器学习示例都使用小型、易于管理的数据集。您可以将整个事情放入内存中,进行尝试,并在几分钟内训练一个模型。但在现实场景中(例如信用卡欺诈检测、推荐系统或财务预测),您要处理数百万甚至数十亿行。突然间,您的笔记本电脑或服务器无法处理它。

如果您尝试将所有数据一次性加载到 PyTorch 或 TensorFlow 中,事情将会崩溃。这些框架是为模型训练而设计的,而不是为了有效处理巨大的数据集。这就是分布式计算变得至关重要的地方。
为什么 PyTorch 和 TensorFlow 还不够
PyTorch 和 TensorFlow 非常适合构建和优化模型,但在处理大规模数据任务时却表现不佳。两大问题:

  • 内存过载:他们在训练之前将整个数据集加载到内存中。这适用于小型数据集,但当您拥有 TB 级的数据时,游戏就结束了。
  • 无分布式数据处理:PyTorch 和 TensorFlow 不是为处理分布式数据处理而构建的。如果您有大量数据分布在多台机器上,它们并没有真正的帮助。

这就是 PySpark 的闪光点。它旨在处理分布式数据,在多台机器上高效处理数据,同时处理大量数据集,而不会导致系统崩溃。

真实示例:使用 PySpark 检测信用卡欺诈
让我们深入研究一个例子。假设您正在开发使用信用卡交易数据的欺诈检测系统。在本例中,我们将使用 Kaggle 的流行数据集。它包含超过 284,000 笔交易,其中不到 1% 是欺诈交易。

第 1 步:在 Google Colab 中设置 PySpark
为此,我们将使用 Google Colab,因为它允许我们以最少的设置运行 PySpark。

!pip install pyspark

接下来,导入必要的库并启动 Spark 会话。

import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, udf
from pyspark.ml.feature import VectorAssembler, StringIndexer, MinMaxScaler
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.linalg import Vectors
import numpy as np
from pyspark.sql.types import FloatType

启动 pyspark 会话

spark = SparkSession.builder \
    .appName("FraudDetectionImproved") \
    .master("local[*]") \
    .config("spark.executorEnv.PYTHONHASHSEED", "0") \
    .getOrCreate()

第 2 步:加载和准备数据

data = spark.read.csv('creditcard.csv', header=True, inferSchema=True)
data = data.orderBy("Time")  # Ensure data is sorted by time
data.show(5)
data.describe().show()
# Check for missing values in each column
data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns]).show()

# Prepare the feature columns
feature_columns = data.columns
feature_columns.remove("Class")  # Removing "Class" column as it is our label

# Assemble features into a single vector
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
data = assembler.transform(data)
data.select("features", "Class").show(5)

# Split data into train (60%), test (20%), and unseen (20%)
train_data, temp_data = data.randomSplit([0.6, 0.4], seed=42)
test_data, unseen_data = temp_data.randomSplit([0.5, 0.5], seed=42)

# Print class distribution in each dataset
print("Train Data:")
train_data.groupBy("Class").count().show()

print("Test and parameter optimisation Data:")
test_data.groupBy("Class").count().show()

print("Unseen Data:")
unseen_data.groupBy("Class").count().show()

第3步:初始化模型

# Initialize RandomForestClassifier
rf = RandomForestClassifier(labelCol="Class", featuresCol="features", probabilityCol="probability")

# Create ParamGrid for Cross Validation
paramGrid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [10, 20 ]) \
    .addGrid(rf.maxDepth, [5, 10]) \
    .build()

# Create 5-fold CrossValidator
crossval = CrossValidator(estimator=rf,
                          estimatorParamMaps=paramGrid,
                          evaluator=BinaryClassificationEvaluator(labelCol="Class", metricName="areaUnderROC"),
                          numFolds=5)

第 4 步:拟合、运行交叉验证,并选择最佳参数集

# Run cross-validation, and choose the best set of parameters
rf_model = crossval.fit(train_data)

# Make predictions on test data
predictions_rf = rf_model.transform(test_data)

# Evaluate Random Forest Model
binary_evaluator = BinaryClassificationEvaluator(labelCol="Class", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
pr_evaluator = BinaryClassificationEvaluator(labelCol="Class", rawPredictionCol="rawPrediction", metricName="areaUnderPR")

auc_rf = binary_evaluator.evaluate(predictions_rf)
auprc_rf = pr_evaluator.evaluate(predictions_rf)
print(f"Random Forest - AUC: {auc_rf:.4f}, AUPRC: {auprc_rf:.4f}")

# UDF to extract positive probability from probability vector
extract_prob = udf(lambda prob: float(prob[1]), FloatType())
predictions_rf = predictions_rf.withColumn("positive_probability", extract_prob(col("probability")))

Step 5 计算精确率、召回率和 F1 分数的函数

# Function to calculate precision, recall, and F1-score
def calculate_metrics(predictions):
    tp = predictions.filter((col("Class") == 1) & (col("prediction") == 1)).count()
    fp = predictions.filter((col("Class") == 0) & (col("prediction") == 1)).count()
    fn = predictions.filter((col("Class") == 1) & (col("prediction") == 0)).count()

    precision = tp / (tp   fp) if (tp   fp) != 0 else 0
    recall = tp / (tp   fn) if (tp   fn) != 0 else 0
    f1_score = (2 * precision * recall) / (precision   recall) if (precision   recall) != 0 else 0

    return precision, recall, f1_score

第 6 步:找到模型的最佳阈值

# Find the best threshold for the model
best_threshold = 0.5
best_f1 = 0
for threshold in np.arange(0.1, 0.9, 0.1):
    thresholded_predictions = predictions_rf.withColumn("prediction", (col("positive_probability") > threshold).cast("double"))
    precision, recall, f1 = calculate_metrics(thresholded_predictions)

    if f1 > best_f1:
        best_f1 = f1
        best_threshold = threshold

print(f"Best threshold: {best_threshold}, Best F1-score: {best_f1:.4f}")

第七步:评估未见过的数据

# Evaluate on unseen data
predictions_unseen = rf_model.transform(unseen_data)
auc_unseen = binary_evaluator.evaluate(predictions_unseen)
print(f"Unseen Data - AUC: {auc_unseen:.4f}")

precision, recall, f1 = calculate_metrics(predictions_unseen)
print(f"Unseen Data - Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")

area_under_roc = binary_evaluator.evaluate(predictions_unseen)
area_under_pr = pr_evaluator.evaluate(predictions_unseen)
print(f"Unseen Data - AUC: {area_under_roc:.4f}, AUPRC: {area_under_pr:.4f}")

结果

Best threshold: 0.30000000000000004, Best F1-score: 0.9062
Unseen Data - AUC: 0.9384
Unseen Data - Precision: 0.9655, Recall: 0.7568, F1-score: 0.8485
Unseen Data - AUC: 0.9423, AUPRC: 0.8618

然后您可以保存此模型(几KB)并在 pyspark 管道中的任何地方使用它

rf_model.save()

这就是 PySpark 在处理现实机器学习任务中的大型数据集时产生巨大差异的原因:

轻松扩展:PySpark 可以跨集群分配任务,使您能够在不耗尽内存的情况下处理 TB 级的数据。
动态数据处理:PySpark 不需要将整个数据集加载到内存中。它根据需要处理数据,这使得它更加高效。
更快的模型训练:借助分布式计算,您可以通过在多台机器上分配计算工作负载来更快地训练模型。
最后的想法
PyTorch 和 TensorFlow 是构建机器学习模型的绝佳工具,但对于现实世界的大规模任务,您需要更多。使用 PySpark 进行分布式计算可让您高效处理庞大数据集、实时处理数据并扩展机器学习管道。

因此,下次您处理海量数据时(无论是欺诈检测、推荐系统还是财务分析),请考虑使用 PySpark 将您的项目提升到一个新的水平。

有关完整的代码和结果,请查看此笔记本。 :
https://colab.research.google.com/drive/1W9naxNZirirLRodSEnHAUWevYd5LH8D4?authuser=5#scrollTo=odmodmqKcY23

__

我是 Swapnil,请随意留下您的评论、结果和想法,或者联系我 - [email protected] 获取数据、软件开发工作和工作信息

版本声明 本文转载于:https://dev.to/femtyfem/why-you-need-distributed-computing-for-real-world-machine-learning-17oo?1如有侵犯,请联系[email protected]删除
最新教程 更多>
  • Redis 解释:主要功能、用例和实践项目
    Redis 解释:主要功能、用例和实践项目
    Introduction Redis is an open-source, in-memory data structure store used as a database, cache, and message broker. It’s known for its perfor...
    编程 发布于2024-11-06
  • 如何在 macOS 上设置 MySQL 自动启动:开发人员分步指南
    如何在 macOS 上设置 MySQL 自动启动:开发人员分步指南
    作为开发人员,我们经常发现自己在本地计算机上使用 MySQL 数据库。虽然每次系统启动时手动启动 MySQL 是可以管理的,但这可能是一项乏味的任务。在本指南中,我们将逐步介绍将 MySQL 设置为在 macOS 上自动启动的过程,从而节省您的时间并简化您的工作流程。 先决条件 在我...
    编程 发布于2024-11-06
  • 掌握 TypeScript:了解扩展的力量
    掌握 TypeScript:了解扩展的力量
    TypeScript 中的 extends 关键字就像一把瑞士军刀。它用于多种上下文,包括继承、泛型和条件类型。了解如何有效地使用扩展可以生成更健壮、可重用和类型安全的代码。 使用扩展进行继承 extends 的主要用途之一是继承,允许您创建基于现有接口或类的新接口或类。 inter...
    编程 发布于2024-11-06
  • 如何将具有组计数的列添加到 Pandas 中的分组数据框?
    如何将具有组计数的列添加到 Pandas 中的分组数据框?
    如何在Pandas中向分组数据框中添加列在数据分析中,经常需要对数据进行分组并进行计算每组。 Pandas 通过其 groupby 函数提供了一种便捷的方法来做到这一点。一个常见的任务是计算每个组中某一列的值,并将包含这些计数的列添加到数据帧中。考虑数据帧 df:df = pd.DataFrame(...
    编程 发布于2024-11-06
  • 破解编码面试的热门必备书籍(从初级到高级排名)
    破解编码面试的热门必备书籍(从初级到高级排名)
    准备编码面试可能是一个充满挑战的旅程,但拥有正确的资源可以让一切变得不同。无论您是从算法开始的初学者、专注于系统设计的中级开发人员,还是完善编码实践的高级工程师,这份按难度排名的前 10 本书列表都将为您提供成功所需的知识和技能。你的软件工程面试。这些书籍涵盖了从基本算法到系统设计和简洁编码原则的所...
    编程 发布于2024-11-06
  • Java 字符串实习初学者指南
    Java 字符串实习初学者指南
    Java String Interning 引入了通过在共享池中存储唯一字符串来优化内存的概念,减少重复对象。它解释了 Java 如何自动实习字符串文字以及开发人员如何使用 intern() 方法手动将字符串添加到池中。 通过掌握字符串驻留,您可以提高 Java 应用程序的性能和内存效率。要深入...
    编程 发布于2024-11-06
  • 如何在 GUI 应用程序中的不同页面之间共享变量数据?
    如何在 GUI 应用程序中的不同页面之间共享变量数据?
    如何从类中获取变量数据在 GUI 编程环境中,单个应用程序窗口中包含多个页面是很常见的。每个页面可能包含各种小部件,例如输入字段、按钮或标签。当与这些小部件交互时,用户提供输入或做出需要在不同页面之间共享的选择。这就提出了如何从一个类访问另一个类的变量数据的问题,特别是当这些类代表不同的页面时。利用...
    编程 发布于2024-11-06
  • React 中的动态路由
    React 中的动态路由
    React 中的动态路由允许您基于动态数据或参数创建路由,从而在应用程序中实现更灵活、更强大的导航。这对于需要根据用户输入或其他动态因素呈现不同组件的应用程序特别有用。 使用 React Router 设置动态路由 您通常会使用react-router-dom库在React中实现动态路由。这是分步指...
    编程 发布于2024-11-06
  • 大批
    大批
    方法是可以在对象上调用的 fns 数组是对象,因此它们在 JS 中也有方法。 slice(begin):将数组的一部分提取到新数组中,而不改变原始数组。 let arr = ['a','b','c','d','e']; // Usecase: Extract till index p...
    编程 发布于2024-11-06
  • WPF中延迟操作时如何避免UI冻结?
    WPF中延迟操作时如何避免UI冻结?
    WPF 中的延迟操作WPF 中的延迟操作对于增强用户体验和确保平滑过渡至关重要。一种常见的情况是在导航到新窗口之前添加延迟。为了实现此目的,经常使用 Thread.Sleep,如提供的代码片段中所示。但是,在延迟过程中,使用 Thread.Sleep 阻塞 UI 线程会导致 UI 无响应。这表现为在...
    编程 发布于2024-11-06
  • 利用 Java 进行实时数据流和处理
    利用 Java 进行实时数据流和处理
    In today's data-driven world, the ability to process and analyze data in real-time is crucial for businesses to make informed decisions swiftly. Java...
    编程 发布于2024-11-06
  • 如何修复损坏的 InnoDB 表?
    如何修复损坏的 InnoDB 表?
    从 InnoDB 表损坏中恢复灾难性事件可能会导致数据库表严重损坏,特别是 InnoDB 表。遇到这种情况时,了解可用的修复选项就变得至关重要。InnoDB Table Corruption Symptoms查询中描述的症状,包括事务日志中的时间戳错误InnoDB 表的修复策略虽然已经有修复 MyI...
    编程 发布于2024-11-06
  • JavaScript 数组和对象中是否正式允许使用尾随逗号?
    JavaScript 数组和对象中是否正式允许使用尾随逗号?
    数组和对象中的尾随逗号:标准还是容忍?数组和对象中尾随逗号的存在引发了一些关于它们的争论JavaScript 的标准化。这个问题源于在不同浏览器中观察到的不一致行为,特别是旧版本的 Internet Explorer。规范状态根据 ECMAScript 5 规范(第 11.1.5 节) ),对象字面...
    编程 发布于2024-11-06
  • 最佳引导模板生成器
    最佳引导模板生成器
    在当今快速发展的数字环境中,速度和效率是关键,网页设计师和开发人员越来越依赖 Bootstrap 构建器来简化他们的工作流程。这些工具可以快速创建响应灵敏、具有视觉吸引力的网站,使团队能够比以往更快地将他们的想法变为现实。 Bootstrap 构建器真正改变了网站的构建方式,使该过程更加易于访问和高...
    编程 发布于2024-11-06
  • 简化 NestJS 中的文件上传:无需磁盘存储即可高效内存中解析 CSV 和 XLSX
    简化 NestJS 中的文件上传:无需磁盘存储即可高效内存中解析 CSV 和 XLSX
    Effortless File Parsing in NestJS: Manage CSV and XLSX Uploads in Memory for Speed, Security, and Scalability Introduction Handling file uploa...
    编程 发布于2024-11-06

免责声明: 提供的所有资源部分来自互联网,如果有侵犯您的版权或其他权益,请说明详细缘由并提供版权或权益证明然后发到邮箱:[email protected] 我们会第一时间内为您处理。

Copyright© 2022 湘ICP备2022001581号-3