」工欲善其事,必先利其器。「—孔子《論語.錄靈公》
首頁 > 程式設計 > 為什麼現實世界的機器學習需要分散式運算

為什麼現實世界的機器學習需要分散式運算

發佈於2024-09-16
瀏覽:364

Why You Need Distributed Computing for Real-World Machine Learning

PySpark 如何帮助您像专业人士一样处理庞大的数据集

PyTorch 和 TensorFlow 等机器学习框架非常适合构建模型。但现实是,当涉及到现实世界的项目时(处理巨大的数据集),您需要的不仅仅是一个好的模型。您需要一种有效处理和管理所有数据的方法。这就是像 PySpark 这样的分布式计算可以拯救世界的地方。

让我们来分析一下为什么在现实世界的机器学习中处理大数据意味着超越 PyTorch 和 TensorFlow,以及 PySpark 如何帮助您实现这一目标。
真正的问题:大数据
您在网上看到的大多数机器学习示例都使用小型、易于管理的数据集。您可以将整个事情放入内存中,进行尝试,并在几分钟内训练一个模型。但在现实场景中(例如信用卡欺诈检测、推荐系统或财务预测),您要处理数百万甚至数十亿行。突然间,您的笔记本电脑或服务器无法处理它。

如果您尝试将所有数据一次性加载到 PyTorch 或 TensorFlow 中,事情将会崩溃。这些框架是为模型训练而设计的,而不是为了有效处理巨大的数据集。这就是分布式计算变得至关重要的地方。
为什么 PyTorch 和 TensorFlow 还不够
PyTorch 和 TensorFlow 非常适合构建和优化模型,但在处理大规模数据任务时却表现不佳。两大问题:

  • 内存过载:他们在训练之前将整个数据集加载到内存中。这适用于小型数据集,但当您拥有 TB 级的数据时,游戏就结束了。
  • 无分布式数据处理:PyTorch 和 TensorFlow 不是为处理分布式数据处理而构建的。如果您有大量数据分布在多台机器上,它们并没有真正的帮助。

这就是 PySpark 的闪光点。它旨在处理分布式数据,在多台机器上高效处理数据,同时处理大量数据集,而不会导致系统崩溃。

真实示例:使用 PySpark 检测信用卡欺诈
让我们深入研究一个例子。假设您正在开发使用信用卡交易数据的欺诈检测系统。在本例中,我们将使用 Kaggle 的流行数据集。它包含超过 284,000 笔交易,其中不到 1% 是欺诈交易。

第 1 步:在 Google Colab 中设置 PySpark
为此,我们将使用 Google Colab,因为它允许我们以最少的设置运行 PySpark。

!pip install pyspark

接下来,导入必要的库并启动 Spark 会话。

import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, udf
from pyspark.ml.feature import VectorAssembler, StringIndexer, MinMaxScaler
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.linalg import Vectors
import numpy as np
from pyspark.sql.types import FloatType

启动 pyspark 会话

spark = SparkSession.builder \
    .appName("FraudDetectionImproved") \
    .master("local[*]") \
    .config("spark.executorEnv.PYTHONHASHSEED", "0") \
    .getOrCreate()

第 2 步:加载和准备数据

data = spark.read.csv('creditcard.csv', header=True, inferSchema=True)
data = data.orderBy("Time")  # Ensure data is sorted by time
data.show(5)
data.describe().show()
# Check for missing values in each column
data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns]).show()

# Prepare the feature columns
feature_columns = data.columns
feature_columns.remove("Class")  # Removing "Class" column as it is our label

# Assemble features into a single vector
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
data = assembler.transform(data)
data.select("features", "Class").show(5)

# Split data into train (60%), test (20%), and unseen (20%)
train_data, temp_data = data.randomSplit([0.6, 0.4], seed=42)
test_data, unseen_data = temp_data.randomSplit([0.5, 0.5], seed=42)

# Print class distribution in each dataset
print("Train Data:")
train_data.groupBy("Class").count().show()

print("Test and parameter optimisation Data:")
test_data.groupBy("Class").count().show()

print("Unseen Data:")
unseen_data.groupBy("Class").count().show()

第3步:初始化模型

# Initialize RandomForestClassifier
rf = RandomForestClassifier(labelCol="Class", featuresCol="features", probabilityCol="probability")

# Create ParamGrid for Cross Validation
paramGrid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [10, 20 ]) \
    .addGrid(rf.maxDepth, [5, 10]) \
    .build()

# Create 5-fold CrossValidator
crossval = CrossValidator(estimator=rf,
                          estimatorParamMaps=paramGrid,
                          evaluator=BinaryClassificationEvaluator(labelCol="Class", metricName="areaUnderROC"),
                          numFolds=5)

第 4 步:拟合、运行交叉验证,并选择最佳参数集

# Run cross-validation, and choose the best set of parameters
rf_model = crossval.fit(train_data)

# Make predictions on test data
predictions_rf = rf_model.transform(test_data)

# Evaluate Random Forest Model
binary_evaluator = BinaryClassificationEvaluator(labelCol="Class", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
pr_evaluator = BinaryClassificationEvaluator(labelCol="Class", rawPredictionCol="rawPrediction", metricName="areaUnderPR")

auc_rf = binary_evaluator.evaluate(predictions_rf)
auprc_rf = pr_evaluator.evaluate(predictions_rf)
print(f"Random Forest - AUC: {auc_rf:.4f}, AUPRC: {auprc_rf:.4f}")

# UDF to extract positive probability from probability vector
extract_prob = udf(lambda prob: float(prob[1]), FloatType())
predictions_rf = predictions_rf.withColumn("positive_probability", extract_prob(col("probability")))

Step 5 计算精确率、召回率和 F1 分数的函数

# Function to calculate precision, recall, and F1-score
def calculate_metrics(predictions):
    tp = predictions.filter((col("Class") == 1) & (col("prediction") == 1)).count()
    fp = predictions.filter((col("Class") == 0) & (col("prediction") == 1)).count()
    fn = predictions.filter((col("Class") == 1) & (col("prediction") == 0)).count()

    precision = tp / (tp   fp) if (tp   fp) != 0 else 0
    recall = tp / (tp   fn) if (tp   fn) != 0 else 0
    f1_score = (2 * precision * recall) / (precision   recall) if (precision   recall) != 0 else 0

    return precision, recall, f1_score

第 6 步:找到模型的最佳阈值

# Find the best threshold for the model
best_threshold = 0.5
best_f1 = 0
for threshold in np.arange(0.1, 0.9, 0.1):
    thresholded_predictions = predictions_rf.withColumn("prediction", (col("positive_probability") > threshold).cast("double"))
    precision, recall, f1 = calculate_metrics(thresholded_predictions)

    if f1 > best_f1:
        best_f1 = f1
        best_threshold = threshold

print(f"Best threshold: {best_threshold}, Best F1-score: {best_f1:.4f}")

第七步:评估未见过的数据

# Evaluate on unseen data
predictions_unseen = rf_model.transform(unseen_data)
auc_unseen = binary_evaluator.evaluate(predictions_unseen)
print(f"Unseen Data - AUC: {auc_unseen:.4f}")

precision, recall, f1 = calculate_metrics(predictions_unseen)
print(f"Unseen Data - Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")

area_under_roc = binary_evaluator.evaluate(predictions_unseen)
area_under_pr = pr_evaluator.evaluate(predictions_unseen)
print(f"Unseen Data - AUC: {area_under_roc:.4f}, AUPRC: {area_under_pr:.4f}")

结果

Best threshold: 0.30000000000000004, Best F1-score: 0.9062
Unseen Data - AUC: 0.9384
Unseen Data - Precision: 0.9655, Recall: 0.7568, F1-score: 0.8485
Unseen Data - AUC: 0.9423, AUPRC: 0.8618

然后您可以保存此模型(几KB)并在 pyspark 管道中的任何地方使用它

rf_model.save()

这就是 PySpark 在处理现实机器学习任务中的大型数据集时产生巨大差异的原因:

轻松扩展:PySpark 可以跨集群分配任务,使您能够在不耗尽内存的情况下处理 TB 级的数据。
动态数据处理:PySpark 不需要将整个数据集加载到内存中。它根据需要处理数据,这使得它更加高效。
更快的模型训练:借助分布式计算,您可以通过在多台机器上分配计算工作负载来更快地训练模型。
最后的想法
PyTorch 和 TensorFlow 是构建机器学习模型的绝佳工具,但对于现实世界的大规模任务,您需要更多工具。使用 PySpark 进行分布式计算可让您高效处理庞大数据集、实时处理数据并扩展机器学习管道。

因此,下次您处理海量数据时(无论是欺诈检测、推荐系统还是财务分析),请考虑使用 PySpark 将您的项目提升到一个新的水平。

有关完整的代码和结果,请查看此笔记本。 :
https://colab.research.google.com/drive/1W9naxNZirirLRodSEnHAUWevYd5LH8D4?authuser=5#scrollTo=odmodmqKcY23

__

我是 Swapnil,请随意留下您的评论、结果和想法,或者联系我 - [email protected] 获取数据、软件开发工作和工作信息

版本聲明 本文轉載於:https://dev.to/femtyfem/why-you-need-distributed-computing-for-real-world-machine-learning-17oo?1如有侵犯,請聯絡[email protected]刪除
最新教學 更多>
  • 如何解決 JLabel 拖放的滑鼠事件衝突?
    如何解決 JLabel 拖放的滑鼠事件衝突?
    用於拖放的JLabel 滑鼠事件:解決滑鼠事件衝突為了在JLabel 上啟用拖放功能,滑鼠事件必須被覆蓋。然而,當嘗試使用 mousePressed 事件實作拖放時,會出現一個常見問題,因為 mouseReleased 事件對該 JLabel 無效。 提供的程式碼在 mousePressed 事件中...
    程式設計 發佈於2024-11-06
  • MySQL 中的資料庫分片:綜合指南
    MySQL 中的資料庫分片:綜合指南
    随着数据库变得越来越大、越来越复杂,有效地控制性能和扩展就出现了。数据库分片是用于克服这些障碍的一种方法。称为“分片”的数据库分区将大型数据库划分为更小、更易于管理的段(称为“分片”)。通过将每个分片分布在多个服务器上(每个服务器保存总数据的一小部分),可以提高可扩展性和吞吐量。 在本文中,我们将探...
    程式設計 發佈於2024-11-06
  • 如何將 Python 日期時間物件轉換為秒?
    如何將 Python 日期時間物件轉換為秒?
    在Python 中將日期時間物件轉換為秒在Python 中使用日期時間物件時,通常需要將它們轉換為秒以適應各種情況分析目的。但是,toordinal() 方法可能無法提供所需的輸出,因為它僅區分具有不同日期的日期。 要準確地將日期時間物件轉換為秒,特別是對於 1970 年 1 月 1 日的特定日期,...
    程式設計 發佈於2024-11-06
  • 如何使用 Laravel Eloquent 的 firstOrNew() 方法有效最佳化 CRUD 操作?
    如何使用 Laravel Eloquent 的 firstOrNew() 方法有效最佳化 CRUD 操作?
    使用 Laravel Eloquent 優化 CRUD 操作在 Laravel 中使用資料庫時,插入或更新記錄是很常見的。為了實現這一點,開發人員經常求助於條件語句,在決定執行插入或更新之前檢查記錄是否存在。 firstOrNew() 方法幸運的是, Eloquent 透過firstOrNew() ...
    程式設計 發佈於2024-11-06
  • 為什麼在 PHP 中重寫方法參數違反了嚴格的標準?
    為什麼在 PHP 中重寫方法參數違反了嚴格的標準?
    在PHP 中重寫方法參數:違反嚴格標準在物件導向程式設計中,里氏替換原則(LSP) 規定:子類型的物件可以替換其父對象,而不改變程式的行為。然而,在 PHP 中,用不同的參數簽名覆蓋方法被認為是違反嚴格標準的。 為什麼這是違規? PHP 是弱型別語言,這表示編譯器無法在編譯時確定變數的確切型別。這表...
    程式設計 發佈於2024-11-06
  • 哪個 PHP 函式庫提供卓越的 SQL 注入防護:PDO 還是 mysql_real_escape_string?
    哪個 PHP 函式庫提供卓越的 SQL 注入防護:PDO 還是 mysql_real_escape_string?
    PDO vs. mysql_real_escape_string:綜合指南查詢轉義對於防止 SQL 注入至關重要。雖然 mysql_real_escape_string 提供了轉義查詢的基本方法,但 PDO 成為了一種具有眾多優點的卓越解決方案。 什麼是 PDO? PHP 資料物件 (PDO) 是一...
    程式設計 發佈於2024-11-06
  • React 入門:初學者的路線圖
    React 入門:初學者的路線圖
    大家好! ? 我剛開始學習 React.js 的旅程。這是一次令人興奮(有時甚至具有挑戰性!)的冒險,我想分享一下幫助我開始的步驟,以防您也開始研究 React。這是我的處理方法: 1.掌握 JavaScript 基礎 在開始使用 React 之前,我確保溫習一下我的 JavaScript 技能,...
    程式設計 發佈於2024-11-06
  • 如何引用 JavaScript 物件中的內部值?
    如何引用 JavaScript 物件中的內部值?
    如何在JavaScript 物件中引用內部值在JavaScript 中,存取引用同一物件中其他值的物件中的值有時可能具有挑戰性。考慮以下程式碼片段:var obj = { key1: "it ", key2: key1 " works!" }; a...
    程式設計 發佈於2024-11-06
  • Python 列表方法快速指南及範例
    Python 列表方法快速指南及範例
    介紹 Python 清單用途廣泛,並附帶各種內建方法,有助於有效地操作和處理資料。以下是所有主要清單方法的快速參考以及簡短的範例。 1. 追加(項目) 將項目新增至清單末端。 lst = [1, 2, 3] lst.append(4) # [1, 2, 3, ...
    程式設計 發佈於2024-11-06
  • C++ 中何時需要使用者定義的複製建構函式?
    C++ 中何時需要使用者定義的複製建構函式?
    何時需要使用者定義的複製建構子? 複製建構子是 C 物件導向程式設計的組成部分,提供了一種基於現有實例初始化物件的方法。雖然編譯器通常會為類別產生預設的複製建構函數,但在某些情況下需要進行自訂。 需要使用者定義複製建構子的情況當預設複製建構子不夠時,程式設計師會選擇使用者定義的複製建構子來實作自訂複...
    程式設計 發佈於2024-11-06
  • 試...捕捉 V/s 安全分配 (?=):現代發展的福音還是詛咒?
    試...捕捉 V/s 安全分配 (?=):現代發展的福音還是詛咒?
    最近,我發現了 JavaScript 中引入的新安全賦值運算子 (?.=),我對它的簡單性著迷。 ? 安全賦值運算子 (SAO) 是傳統 try...catch 區塊的簡寫替代方案。它允許您內聯捕獲錯誤,而無需為每個操作編寫明確的錯誤處理程式碼。這是一個例子: const [error, resp...
    程式設計 發佈於2024-11-06
  • 如何在Python中優化固定寬度檔案解析?
    如何在Python中優化固定寬度檔案解析?
    優化固定寬度文件解析為了有效地解析固定寬度文件,可以考慮利用Python的struct模組。此方法利用 C 來提高速度,如下例所示:import struct fieldwidths = (2, -10, 24) fmtstring = ' '.join('{}{}'.format(abs(fw),...
    程式設計 發佈於2024-11-06
  • 蠅量級
    蠅量級
    結構模式之一旨在透過與相似物件共享盡可能多的資料來減少記憶體使用。 在處理大量相似物件時特別有用,為每個物件建立一個新實例在記憶體消耗方面會非常昂貴。 關鍵概念: 內在狀態:多個物件之間共享的狀態獨立於上下文,並且在不同物件之間保持相同。 外部狀態:每個物件唯一的、從客戶端傳遞的狀態。此狀態可...
    程式設計 發佈於2024-11-06
  • 解鎖您的 MySQL 掌握:MySQL 實作實驗室課程
    解鎖您的 MySQL 掌握:MySQL 實作實驗室課程
    透過全面的 MySQL 實作實驗室課程提升您的 MySQL 技能並成為資料庫專家。這種實踐學習體驗旨在引導您完成一系列實踐練習,使您能夠克服複雜的 SQL 挑戰並優化資料庫效能。 深入了解 MySQL 無論您是想要建立強大 MySQL 基礎的初學者,還是想要提升專業知識的經驗豐富的...
    程式設計 發佈於2024-11-06

免責聲明: 提供的所有資源部分來自互聯網,如果有侵犯您的版權或其他權益,請說明詳細緣由並提供版權或權益證明然後發到郵箱:[email protected] 我們會在第一時間內為您處理。

Copyright© 2022 湘ICP备2022001581号-3