欢迎光临散文网 会员登陆 & 注册

Spark的LSH示例代码解析(java)

2021-05-11 20:34 作者:寂风也过路  | 我要投稿

因为自己的项目需要,所以近期用到了Spark的这个LSH局部哈希算法工具,因为示例代码没啥解释,所以花费了一些时间看官方文档来理解并增加大量注释,希望能给后人带来一些帮助。

示例代码地址:

https://spark.apache.org/docs/3.1.1/ml-features.html#lsh-algorithms

import java.util.Arrays;

import java.util.List;


import org.apache.spark.ml.feature.BucketedRandomProjectionLSH;

import org.apache.spark.ml.feature.BucketedRandomProjectionLSHModel;

import org.apache.spark.ml.linalg.Vector;

import org.apache.spark.ml.linalg.Vectors;

import org.apache.spark.ml.linalg.VectorUDT;

import org.apache.spark.sql.Dataset;

import org.apache.spark.sql.Row;

import org.apache.spark.sql.RowFactory;

import org.apache.spark.sql.SparkSession;

import org.apache.spark.sql.types.DataTypes;

import org.apache.spark.sql.types.Metadata;

import org.apache.spark.sql.types.StructField;

import org.apache.spark.sql.types.StructType;


import static org.apache.spark.sql.functions.col;


public class SimilarLSHUtil {

    // Spark在springboot中应用的探索

    public void testFunction() {

        // SparkSession 是 Spark SQL 的入口,使用Dataset和DataFrame API编程Spark的入口点。

        // 使用 Dataset 或者 Datafram 编写 Spark SQL 应用的时候,第一个要创建的对象就是 SparkSession。

        SparkSession spark = SparkSession.builder()

                .master("local[*]")  // 设置要连接的Spark主URL,例如“ local”在本地运行,“ local [4]”在4核本地运行,或“ spark:// master:7077”在Spark独立集群上运行。

                .appName("Spark")   // 设置应用程序的名称,该名称将显示在Spark Web UI中。

                .getOrCreate();     // 获取一个现有的,SparkSession或者如果不存在,则根据此构建器中设置的选项创建一个新的。


        List<Row> dataA = Arrays.asList(

                RowFactory.create(0, Vectors.dense(1.0, 1.0)), // Vectors.dense构造密集向量,Vectors.sparse构造稀疏向量

                RowFactory.create(1, Vectors.dense(1.0, -1.0)), // RowFactory.create依据给予的参数创建一个row

                RowFactory.create(2, Vectors.dense(-1.0, -1.0)), // RowFactory是一个构造Row对象的工厂类

                RowFactory.create(3, Vectors.dense(-1.0, 1.0))

        );


        List<Row> dataB = Arrays.asList(

                RowFactory.create(4, Vectors.dense(1.0, 0.0)),

                RowFactory.create(5, Vectors.dense(-1.0, 0.0)),

                RowFactory.create(6, Vectors.dense(0.0, 1.0)),

                RowFactory.create(7, Vectors.dense(0.0, -1.0))

        );


        // 一个StructType对象,可以有多个StructField,同时也可以用名字(name)来提取,就想当于Map可以用key来提取value

        /*

        // Case Class(样例类)是一种特殊的类,它们经过优化以被用于模式匹配。

        case class StructField(

            name: String,                        // 此字段的名称

            dataType: DataType,                  // 此字段的数据类型

            nullable: Boolean = true,            // 指示此字段是否可以为null

            metadata: Metadata = Metadata.empty  // 元数据,此字段的元数据

        ) {}

        */

        StructType schema = new StructType(new StructField[]{

                new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),

                new StructField("features", new VectorUDT(), false, Metadata.empty())

        });

        // 数据集是特定于域的对象的强类型集合,可以使用功能或关系操作并行转换它们。每个数据集还具有一个被称为DataFrame的无类型视图,即Row的一个数据集。

        Dataset<Row> dfA = spark.createDataFrame(dataA, schema);

        Dataset<Row> dfB = spark.createDataFrame(dataB, schema);


        BucketedRandomProjectionLSH mh = new BucketedRandomProjectionLSH()

                .setBucketLength(2.0)     // 设置每个哈希存储桶的长度,较大的存储桶可降低误报率。如果将输入向量标准化,则pow(numRecords,-1 / inputDim)的1-10倍是合理的值

                .setNumHashTables(3)      // LSH OR扩增中使用的哈希表数量的参数。

                .setInputCol("features")  // 设置输入列的名称

                .setOutputCol("hashes");  // 设置输出列的名称

// LSH OR扩增中使用的哈希表数量的参数.LSH OR放大可用于降低假阴性率。该参数的较高值导致降低的误报率,但以增加的计算复杂性为代价。

// https://cloud.tencent.com/developer/article/1035600

// 使用numHashTables = 5,近似最近邻的速度比完全扫描快2倍。在numHashTables = 3的情况下,近似相似连接比完全连接和过滤要快3-5倍。


        BucketedRandomProjectionLSHModel model = mh.fit(dfA); // fit(Dataset<?> dataset):使模型适合输入数据


        // Feature Transformation 特征转换

        // 散列值存储在“hashes”列中的散列数据集

        System.out.println("The hashed dataset where hashed values are stored in the column 'hashes':");

        // transform(Dataset<?> dataset): 转换输入数据集。

        model.transform(dfA).show();


        // 计算输入行的位置敏感哈希,然后执行近似相似加入

        // 我们可以通过传入已转换的数据集来避免计算哈希,例如:`model.approxSimilarityJoin(transformedA, transformedB, 1.5)`

        System.out.println("Approximately joining dfA and dfB on distance smaller than 1.5:");

        model.approxSimilarityJoin(dfA, dfB, 1.5, "EuclideanDistance") // 连接两个数据集以大致找到距离小于阈值的所有行对。threshold:阈值。返回:Dataset<?>

                .select(col("datasetA.id").alias("idA"), // select(Column... cols)选择一组基于列的表达式。

                        col("datasetB.id").alias("idB"), // col(String colName),根据列名称选择列,并将其作为返回Column。alias(String alias):为列指定别名。

                        col("EuclideanDistance")).show(); // show()以表格形式显示数据集的前20行。


        // 计算输入行的位置敏感哈希,然后执行近似最近相邻搜索。

        // 我们可以通过传入已转换的数据集来避免计算哈希,例如 `model.approxNearestNeighbors(transformedA, key, 2)`

        System.out.println("Approximately searching dfA for 2 nearest neighbors of the key:");

        Vector key = Vectors.dense(1.0, 0.0);

        // 大约在dfA中依据key搜索2个最近邻居

        // approxNearestNeighbors(Dataset<?> dataset, Vector key, int numNearestNeighbors, String distCol) 给定一个大型数据集和一个项目,大约可以找到最多k个与该项目距离最近的项目。

        model.approxNearestNeighbors(dfA, key, 2).show();

    }

}

(以后有时间了,会再出一篇java下的Spark使用配置过程

(以后有时间了,也会顺便写一个如何在java(特别是Springboot项目用这个算法工具

因为我不是搞算法的,只是用到了而已,所以可能部分地方存在认知错误,求轻喷...

Spark的LSH示例代码解析(java)的评论 (共 条)

分享到微博请遵守国家法律