Spark的LSH示例代码解析(java)
因为自己的项目需要,所以近期用到了Spark的这个LSH局部哈希算法工具,因为示例代码没啥解释,所以花费了一些时间看官方文档来理解并增加大量注释,希望能给后人带来一些帮助。
示例代码地址:
https://spark.apache.org/docs/3.1.1/ml-features.html#lsh-algorithms

import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.feature.BucketedRandomProjectionLSH;
import org.apache.spark.ml.feature.BucketedRandomProjectionLSHModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.functions.col;
public class SimilarLSHUtil {
// Spark在springboot中应用的探索
public void testFunction() {
// SparkSession 是 Spark SQL 的入口,使用Dataset和DataFrame API编程Spark的入口点。
// 使用 Dataset 或者 Datafram 编写 Spark SQL 应用的时候,第一个要创建的对象就是 SparkSession。
SparkSession spark = SparkSession.builder()
.master("local[*]") // 设置要连接的Spark主URL,例如“ local”在本地运行,“ local [4]”在4核本地运行,或“ spark:// master:7077”在Spark独立集群上运行。
.appName("Spark") // 设置应用程序的名称,该名称将显示在Spark Web UI中。
.getOrCreate(); // 获取一个现有的,SparkSession或者如果不存在,则根据此构建器中设置的选项创建一个新的。
List<Row> dataA = Arrays.asList(
RowFactory.create(0, Vectors.dense(1.0, 1.0)), // Vectors.dense构造密集向量,Vectors.sparse构造稀疏向量
RowFactory.create(1, Vectors.dense(1.0, -1.0)), // RowFactory.create依据给予的参数创建一个row
RowFactory.create(2, Vectors.dense(-1.0, -1.0)), // RowFactory是一个构造Row对象的工厂类
RowFactory.create(3, Vectors.dense(-1.0, 1.0))
);
List<Row> dataB = Arrays.asList(
RowFactory.create(4, Vectors.dense(1.0, 0.0)),
RowFactory.create(5, Vectors.dense(-1.0, 0.0)),
RowFactory.create(6, Vectors.dense(0.0, 1.0)),
RowFactory.create(7, Vectors.dense(0.0, -1.0))
);
// 一个StructType对象,可以有多个StructField,同时也可以用名字(name)来提取,就想当于Map可以用key来提取value
/*
// Case Class(样例类)是一种特殊的类,它们经过优化以被用于模式匹配。
case class StructField(
name: String, // 此字段的名称
dataType: DataType, // 此字段的数据类型
nullable: Boolean = true, // 指示此字段是否可以为null
metadata: Metadata = Metadata.empty // 元数据,此字段的元数据
) {}
*/
StructType schema = new StructType(new StructField[]{
new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
new StructField("features", new VectorUDT(), false, Metadata.empty())
});
// 数据集是特定于域的对象的强类型集合,可以使用功能或关系操作并行转换它们。每个数据集还具有一个被称为DataFrame的无类型视图,即Row的一个数据集。
Dataset<Row> dfA = spark.createDataFrame(dataA, schema);
Dataset<Row> dfB = spark.createDataFrame(dataB, schema);
BucketedRandomProjectionLSH mh = new BucketedRandomProjectionLSH()
.setBucketLength(2.0) // 设置每个哈希存储桶的长度,较大的存储桶可降低误报率。如果将输入向量标准化,则pow(numRecords,-1 / inputDim)的1-10倍是合理的值
.setNumHashTables(3) // LSH OR扩增中使用的哈希表数量的参数。
.setInputCol("features") // 设置输入列的名称
.setOutputCol("hashes"); // 设置输出列的名称
// LSH OR扩增中使用的哈希表数量的参数.LSH OR放大可用于降低假阴性率。该参数的较高值导致降低的误报率,但以增加的计算复杂性为代价。
// https://cloud.tencent.com/developer/article/1035600
// 使用numHashTables = 5,近似最近邻的速度比完全扫描快2倍。在numHashTables = 3的情况下,近似相似连接比完全连接和过滤要快3-5倍。
BucketedRandomProjectionLSHModel model = mh.fit(dfA); // fit(Dataset<?> dataset):使模型适合输入数据
// Feature Transformation 特征转换
// 散列值存储在“hashes”列中的散列数据集
System.out.println("The hashed dataset where hashed values are stored in the column 'hashes':");
// transform(Dataset<?> dataset): 转换输入数据集。
model.transform(dfA).show();
// 计算输入行的位置敏感哈希,然后执行近似相似加入
// 我们可以通过传入已转换的数据集来避免计算哈希,例如:`model.approxSimilarityJoin(transformedA, transformedB, 1.5)`
System.out.println("Approximately joining dfA and dfB on distance smaller than 1.5:");
model.approxSimilarityJoin(dfA, dfB, 1.5, "EuclideanDistance") // 连接两个数据集以大致找到距离小于阈值的所有行对。threshold:阈值。返回:Dataset<?>
.select(col("datasetA.id").alias("idA"), // select(Column... cols)选择一组基于列的表达式。
col("datasetB.id").alias("idB"), // col(String colName),根据列名称选择列,并将其作为返回Column。alias(String alias):为列指定别名。
col("EuclideanDistance")).show(); // show()以表格形式显示数据集的前20行。
// 计算输入行的位置敏感哈希,然后执行近似最近相邻搜索。
// 我们可以通过传入已转换的数据集来避免计算哈希,例如 `model.approxNearestNeighbors(transformedA, key, 2)`
System.out.println("Approximately searching dfA for 2 nearest neighbors of the key:");
Vector key = Vectors.dense(1.0, 0.0);
// 大约在dfA中依据key搜索2个最近邻居
// approxNearestNeighbors(Dataset<?> dataset, Vector key, int numNearestNeighbors, String distCol) 给定一个大型数据集和一个项目,大约可以找到最多k个与该项目距离最近的项目。
model.approxNearestNeighbors(dfA, key, 2).show();
}
}

(以后有时间了,会再出一篇java下的Spark使用配置过程)
(以后有时间了,也会顺便写一个如何在java(特别是Springboot项目)代码中应用这个算法工具)

因为我不是搞算法的,只是用到了而已,所以可能部分地方存在认知错误,求轻喷...