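As we know, when an inverted index is used for recall (candidate retrieval), matched documents are scored with a text-similarity formula, such as BM25, which is Lucene's default.
Note: Lucene's current implementation has dropped the (k1+1) factor from the numerator. Removing (k1+1) does not affect ranking.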
* LUCENE-8563: BM25 scores don't include the (k1+1) factor in their numerator
anymore. This doesn't affect ordering as this is a constant factor which is
the same for every document. (Luca Cavanna via Adrien Grand)
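For reference, the two forms of the per-term score can be written side by side. The notation below is the usual BM25 convention (tf is the term frequency, dl the document length, avgdl the average document length, k_1 and b the tuning parameters); it is my own rendering, not text from the Lucene changelog.

```latex
\text{classic BM25:}\quad
\mathrm{score}(t,d) = \mathrm{idf}(t)\cdot
  \frac{tf\cdot(k_1+1)}{tf + k_1\left(1-b+b\cdot\frac{dl}{avgdl}\right)}

\text{after LUCENE-8563:}\quad
\mathrm{score}(t,d) = \mathrm{idf}(t)\cdot
  \frac{tf}{tf + k_1\left(1-b+b\cdot\frac{dl}{avgdl}\right)}
```

The two differ only by the positive constant (k_1+1), which is the same for every document scored with the same k_1, so the relative order of documents is unchanged.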
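Besides BM25, other text-similarity scores exist, such as TF-IDF, which Lucene used previously.
At their core, these formulas are all custom combinations of the same handful of features, such as term frequency (tf), inverse document frequency (idf), and document length (dl).
BM25 and TF-IDF have been tested extensively on traditional corpora and perform well there, but they do not necessarily suit every specialized business scenario. For example, BM25 uses the dl parameter to penalize longer texts, whereas in some scenarios a longer text actually indicates a higher-quality document.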
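To make the length penalty concrete, here is a minimal sketch of the BM25 term score (in the form without the (k1+1) factor) in plain Java. The class and method names are made up for illustration, and the sample numbers are arbitrary.

```java
/** Minimal BM25 term score, written out by hand to show the document-length penalty. */
final class Bm25LengthPenaltyDemo {

    /** BM25 without the (k1+1) factor: idf * tf / (tf + k1 * ((1 - b) + b * dl / avgdl)). */
    static double bm25(double idf, double tf, double dl, double avgdl, double k1, double b) {
        double norm = k1 * ((1 - b) + b * dl / avgdl);
        return idf * tf / (tf + norm);
    }

    public static void main(String[] args) {
        double idf = 2.0, tf = 3.0, avgdl = 100.0, k1 = 1.2, b = 0.75;
        // Same term frequency, different document lengths.
        System.out.println("short doc (dl=50):  " + bm25(idf, tf, 50, avgdl, k1, b));
        System.out.println("long doc  (dl=400): " + bm25(idf, tf, 400, avgdl, k1, b));
        // With b = 0 the length term drops out and both documents score the same.
        System.out.println("b=0, dl=50:  " + bm25(idf, tf, 50, avgdl, k1, 0.0));
        System.out.println("b=0, dl=400: " + bm25(idf, tf, 400, avgdl, k1, 0.0));
    }
}
```

With the same tf, the longer document receives the lower score unless b is set to 0, which is exactly the behavior that may be undesirable when length correlates with quality.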
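We therefore need a way to customize the text-similarity score for different business scenarios. In Lucene this can be done by writing a custom Similarity, but hard-coding a single custom formula has problems of its own.
In short, we decided to support dynamic custom formulas at query time.
The design is simple. Using ES syntax to illustrate, a match query used to be written like this: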
{
  "match": {
    "TITLE": {
      "query": "hello world"
    }
  }
}
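Written this way, scoring uses the text-similarity formula configured for the field (BM25 by default).
By adding one parameter, we allow the scoring formula to be changed dynamically at query time: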
// Use tfidf
{
  "match": {
    "TITLE": {
      "query": "hello world",
      "similarity": {
        "name": "tfidf"
      }
    }
  }
}
// Use bm25
{
  "match": {
    "TITLE": {
      "query": "hello world",
      "similarity": {
        "name": "bm25"
      }
    }
  }
}
// Use a custom formula
{
  "match": {
    "TITLE": {
      "query": "hello world",
      "operator": "and",
      "similarity": {
        "name": "custom",
        "expression": "idf*boost*tf/(tf+k*((1-b)+b*dl/avgdl))",
        "params": {
          "k": 1.2,
          "b": 0.75
        }
      }
    }
  }
}
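Here, the names "tfidf" and "bm25" are preset scoring formulas that map to Lucene's implementations, while "custom" is our own implementation: it lets you write a formula dynamically through a custom expression and also define your own parameters.
In the example above, the custom expression reproduces a formula equivalent to BM25.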
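The equivalence can be checked numerically. The sketch below compares the custom expression, written out by hand, with BM25 (without the (k1+1) factor) in an algebraically rearranged form, weight - weight / (1 + tf / norm). That rearranged shape is, to my understanding, close to how Lucene's BM25 scorer computes the value, but treat that as an assumption; all names here are illustrative.

```java
/** Checks that the custom expression matches BM25 (without the (k1+1) factor). */
final class CustomExpressionEquivalenceDemo {

    /** The expression from the "custom" example: idf*boost*tf/(tf+k*((1-b)+b*dl/avgdl)). */
    static double customExpression(double idf, double boost, double tf,
                                   double dl, double avgdl, double k, double b) {
        return idf * boost * tf / (tf + k * ((1 - b) + b * dl / avgdl));
    }

    /** BM25 rearranged as weight - weight / (1 + tf / norm), with weight = boost * idf. */
    static double bm25Rearranged(double idf, double boost, double tf,
                                 double dl, double avgdl, double k1, double b) {
        double weight = boost * idf;
        double normInverse = 1.0 / (k1 * ((1 - b) + b * dl / avgdl));
        return weight - weight / (1.0 + tf * normInverse);
    }

    public static void main(String[] args) {
        double maxDiff = 0.0;
        // Sweep a few term frequencies and document lengths (arbitrary sample values).
        for (double tf = 1; tf <= 5; tf++) {
            for (double dl = 10; dl <= 1000; dl *= 10) {
                double x = customExpression(2.5, 1.0, tf, dl, 120.0, 1.2, 0.75);
                double y = bm25Rearranged(2.5, 1.0, tf, dl, 120.0, 1.2, 0.75);
                maxDiff = Math.max(maxDiff, Math.abs(x - y));
            }
        }
        System.out.println("max absolute difference: " + maxDiff);
    }
}
```

Any remaining difference comes only from floating-point rounding, since weight - weight / (1 + tf / norm) simplifies to weight * tf / (tf + norm).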
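The expression can be any formula. It may reference preset variables directly, and these are replaced with actual values at run time; the CustomScorer code later in this article binds idf, boost, tf, dl, and avgdl.
With the desired behavior clear, let's look at how it is implemented.
First, we need to understand how Lucene picks up the default BM25 similarity.
A match query is decomposed into a bool query that wraps multiple term queries.
For example: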
{
  "match": {
    "TITLE": {
      "query": "hello world",
      "operator": "and"
    }
  }
}
Parsed into a Lucene Query, it has the following logical structure:
BoolQuery
  must:
    TermQuery(TITLE:hello)
    TermQuery(TITLE:world)
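The similarity formula is applied inside TermQuery.
Looking at the TermQuery source (the highlighted code in the original figure), it is easy to see that Lucene's Similarity is defined on the IndexSearcher. Why is it designed this way? The reason is that the Similarity class is also used at index time: the dl value used in scoring has to be written to the index files during indexing, and for extensibility Lucene lets users customize how dl is computed through the computeNorm(state) method.
In practice, however, nearly all of Lucene's Similarity classes, including BM25Similarity and TFIDFSimilarity, implement computeNorm in the same way. So as long as we do not need a custom dl computation, the Similarity can be swapped freely at query time.
The TermQuery code hard-codes the lookup of the Similarity from the IndexSearcher. To make it replaceable at query time, we need our own TermQuery implementation, whose main change is an added Similarity parameter: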
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.similarities.Similarity;
import java.io.IOException;
import java.util.Objects;
import java.util.Set;
/**
* A Query that matches documents containing a term. This may be combined with
* other terms with a {@link BooleanQuery}.
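* This is a modified copy of Lucene's native TermQuery that accepts a custom Similarity at query time.
* Note that at index time Lucene has already called computeNorm() on the field's Similarity and stored the resulting norm.
* The Similarity passed in at query time can be customized, but its computeNorm() should stay consistent with that of the field's own Similarity.
* In general, using BM25's computeNorm() is sufficient.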
*/
public class TermQueryWithSimilarity extends Query {
private final Term term;
private final TermStates perReaderTermState;
// private final String similarity;
private final Similarity similarity;
final class TermWeight extends Weight {
private final Similarity similarity;
private final Similarity.SimScorer simScorer;
private final TermStates termStates;
private final ScoreMode scoreMode;
public TermWeight(IndexSearcher searcher, ScoreMode scoreMode,
float boost, TermStates termStates, Similarity similarity) throws IOException {
super(TermQueryWithSimilarity.this);
if (scoreMode.needsScores() && termStates == null) {
throw new IllegalStateException("termStates are required when scores are needed");
}
this.scoreMode = scoreMode;
this.termStates = termStates;
this.similarity = similarity;
final CollectionStatistics collectionStats;
final TermStatistics termStats;
if (scoreMode.needsScores()) {
collectionStats = searcher.collectionStatistics(term.field());
termStats = termStates.docFreq() > 0 ? searcher.termStatistics(term, termStates.docFreq(), termStates.totalTermFreq()) : null;
} else {
// we do not need the actual stats, use fake stats with docFreq=maxDoc=ttf=1
collectionStats = new CollectionStatistics(term.field(), 1, 1, 1, 1);
termStats = new TermStatistics(term.bytes(), 1, 1);
}
if (termStats == null) {
this.simScorer = null; // term doesn't exist in any segment, we won't use similarity at all
} else {
this.simScorer = similarity.scorer(boost, collectionStats, termStats);
}
}
@Override
public void extractTerms(Set<Term> terms) {
terms.add(getTerm());
}
@Override
public Matches matches(LeafReaderContext context, int doc) throws IOException {
TermsEnum te = getTermsEnum(context);
if (te == null) {
return null;
}
if (context.reader().terms(term.field()).hasPositions() == false) {
return super.matches(context, doc);
}
return MatchesUtils.forField(term.field(), () -> {
PostingsEnum pe = te.postings(null, PostingsEnum.OFFSETS);
if (pe.advance(doc) != doc) {
return null;
}
return new TermMatchesIterator(getQuery(), pe);
});
}
@Override
public String toString() {
return "weight(" + TermQueryWithSimilarity.this + ")";
}
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
assert termStates == null || termStates.wasBuiltFor(ReaderUtil.getTopLevelContext(context)) : "The top-reader used to create Weight is not the same as the current reader's top-reader (" + ReaderUtil.getTopLevelContext(context);
final TermsEnum termsEnum = getTermsEnum(context);
if (termsEnum == null) {
return null;
}
LeafSimScorer scorer = new LeafSimScorer(simScorer, context.reader(), term.field(), scoreMode.needsScores());
if (scoreMode == ScoreMode.TOP_SCORES) {
return new TermScorer(this, termsEnum.impacts(PostingsEnum.FREQS), scorer);
} else {
return new TermScorer(this, termsEnum.postings(null, scoreMode.needsScores() ? PostingsEnum.FREQS : PostingsEnum.NONE), scorer);
}
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
/**
* Returns a {@link TermsEnum} positioned at this weights Term or null if
* the term does not exist in the given context
*/
private TermsEnum getTermsEnum(LeafReaderContext context) throws IOException {
assert termStates != null;
assert termStates.wasBuiltFor(ReaderUtil.getTopLevelContext(context)) :
"The top-reader used to create Weight is not the same as the current reader's top-reader (" + ReaderUtil.getTopLevelContext(context);
final TermState state = termStates.get(context);
if (state == null) { // term is not present in that reader
assert termNotInReader(context.reader(), term) : "no termstate found but term exists in reader term=" + term;
return null;
}
final TermsEnum termsEnum = context.reader().terms(term.field()).iterator();
termsEnum.seekExact(term.bytes(), state);
return termsEnum;
}
private boolean termNotInReader(LeafReader reader, Term term) throws IOException {
// only called from assert
// System.out.println("TQ.termNotInReader reader=" + reader + " term=" +
// field + ":" + bytes.utf8ToString());
return reader.docFreq(term) == 0;
}
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
TermScorer scorer = (TermScorer) scorer(context);
if (scorer != null) {
int newDoc = scorer.iterator().advance(doc);
if (newDoc == doc) {
float freq = scorer.freq();
LeafSimScorer docScorer = new LeafSimScorer(simScorer, context.reader(), term.field(), true);
Explanation freqExplanation = Explanation.match(freq, "freq, occurrences of term within document");
Explanation scoreExplanation = docScorer.explain(doc, freqExplanation);
return Explanation.match(
scoreExplanation.getValue(),
"weight(" + getQuery() + " in " + doc + ") ["
+ similarity.getClass().getSimpleName() + "], result of:",
scoreExplanation);
}
}
return Explanation.noMatch("no matching term");
}
}
/**
* Constructs a query for the term <code>t</code>.
*/
public TermQueryWithSimilarity(Term t, Similarity similarity) {
term = Objects.requireNonNull(t);
perReaderTermState = null;
this.similarity = similarity;
}
/**
* Expert: constructs a TermQuery that will use the provided docFreq instead
* of looking up the docFreq against the searcher.
*/
public TermQueryWithSimilarity(Term t, TermStates states, Similarity similarity) {
assert states != null;
term = Objects.requireNonNull(t);
perReaderTermState = Objects.requireNonNull(states);
this.similarity = similarity;
}
/**
* Returns the term of this query.
*/
public Term getTerm() {
return term;
}
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
final IndexReaderContext context = searcher.getTopReaderContext();
final TermStates termState;
if (perReaderTermState == null
|| perReaderTermState.wasBuiltFor(context) == false) {
termState = TermStates.build(context, term, scoreMode.needsScores());
} else {
// PRTS was pre-build for this IS
termState = this.perReaderTermState;
}
return new TermWeight(searcher, scoreMode, boost, termState, similarity);
}
@Override
public void visit(QueryVisitor visitor) {
if (visitor.acceptField(term.field())) {
visitor.consumeTerms(this, term);
}
}
/**
* Prints a user-readable version of this query.
*/
@Override
public String toString(String field) {
StringBuilder buffer = new StringBuilder();
if (!term.field().equals(field)) {
buffer.append(term.field());
buffer.append(":");
}
buffer.append(term.text());
return buffer.toString();
}
/**
* Returns the {@link TermStates} passed to the constructor, or null if it was not passed.
*
* @lucene.experimental
*/
public TermStates getTermStates() {
return perReaderTermState;
}
/**
* Returns true iff <code>other</code> is equal to <code>this</code>.
*/
@Override
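public boolean equals(Object other) {
// Compare the similarity as well, so that two queries on the same term with
// different scoring functions are not treated as interchangeable.
// A Similarity that does not override equals() falls back to identity comparison.
if (sameClassAs(other) == false) {
return false;
}
TermQueryWithSimilarity that = (TermQueryWithSimilarity) other;
return term.equals(that.term) && Objects.equals(similarity, that.similarity);
}
@Override
public int hashCode() {
return classHash() ^ term.hashCode() ^ Objects.hashCode(similarity);
}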
}
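Most of our TermQueryWithSimilarity is copied directly from TermQuery. The main change is the added similarity field, passed in through the constructor; scoring uses this instance instead of the similarity held by the IndexSearcher.
Next, we modify the match query implementation so that wherever a TermQuery would be created, it uses the method below (parameter parsing and the conversion of the match query into a bool query are omitted):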
public static Query getTermQueryWithSimilarity(String field, String text, Similarity similarity) {
    // No query-level similarity: fall back to the stock TermQuery and the searcher's default.
    if (similarity == null) {
        return new TermQuery(new Term(field, text));
    }
    return new TermQueryWithSimilarity(new Term(field, text), similarity);
}
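With the custom TermQueryWithSimilarity in place, the Similarity can now be chosen at query time.
So how is the custom Similarity itself implemented? The core question is how to obtain the feature values it needs, such as tf and idf.
An example follows: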
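The override-or-default behavior can be illustrated with a toy model that uses no Lucene classes at all. Everything in the sketch below (ToySearcher, ToyTermQuery, the scoring functions) is hypothetical and only mirrors the shape of the design: a searcher owns a default scoring function, and a query may carry its own.

```java
import java.util.function.DoubleUnaryOperator;

/** Toy model (not Lucene code) of a per-query scoring function overriding a searcher-wide default. */
final class SimilarityOverrideDemo {

    /** Hypothetical stand-in for IndexSearcher: owns the default scoring function. */
    static final class ToySearcher {
        final DoubleUnaryOperator defaultSimilarity;

        ToySearcher(DoubleUnaryOperator defaultSimilarity) {
            this.defaultSimilarity = defaultSimilarity;
        }
    }

    /** Hypothetical stand-in for a term query; a null similarity means "use the searcher's default". */
    static final class ToyTermQuery {
        final String term;
        final DoubleUnaryOperator similarity;

        ToyTermQuery(String term, DoubleUnaryOperator similarity) {
            this.term = term;
            this.similarity = similarity;
        }

        /** Scores a term frequency with the query-level function if present, else the default. */
        double score(ToySearcher searcher, double tf) {
            DoubleUnaryOperator effective = similarity != null ? similarity : searcher.defaultSimilarity;
            return effective.applyAsDouble(tf);
        }
    }

    public static void main(String[] args) {
        // The default saturates with tf, loosely in the spirit of BM25.
        ToySearcher searcher = new ToySearcher(tf -> tf / (tf + 1.2));
        ToyTermQuery plain = new ToyTermQuery("hello", null);
        ToyTermQuery custom = new ToyTermQuery("hello", Math::sqrt);
        System.out.println("default:  " + plain.score(searcher, 4.0));
        System.out.println("override: " + custom.score(searcher, 4.0));
    }
}
```

The real classes differ in that the norm (dl) was fixed at index time, which is why the query-level Similarity should keep the same computeNorm() as the field's Similarity.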
package com.zhaopin.solr.search.similarity.custom;
import com.zhaopin.solr.util.Exp4jUtil;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.SmallFloat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Created by jiabao.gao on 2023-04-17.
*/
public class CustomSimilarity extends Similarity {
private static final Similarity BM25_SIM = new BM25Similarity();
private final String expression;
private final Map<String, Float> params;
public CustomSimilarity(String expression, Map<String, Float> params) {
this.expression = expression;
this.params = params;
}
/**
* Implemented as <code>log(1 + (docCount - docFreq + 0.5)/(docFreq + 0.5))</code>.
*/
protected float idf(long docFreq, long docCount) {
return (float) Math.log(1 + (docCount - docFreq + 0.5D) / (docFreq + 0.5D));
}
/**
* The default implementation computes the average as <code>sumTotalTermFreq / docCount</code>
*/
protected float avgFieldLength(CollectionStatistics collectionStats) {
return (float) (collectionStats.sumTotalTermFreq() / (double) collectionStats.docCount());
}
/**
* Cache of decoded bytes.
*/
private static final float[] LENGTH_TABLE = new float[256];
static {
for (int i = 0; i < 256; i++) {
LENGTH_TABLE[i] = SmallFloat.byte4ToInt((byte) i);
}
}
@Override
public long computeNorm(FieldInvertState state) {
return BM25_SIM.computeNorm(state);
}
/**
* Computes a score factor for a simple term and returns an explanation
* for that score factor.
*
* <p>
* The default implementation uses:
*
* <pre class="prettyprint">
* idf(docFreq, docCount);
* </pre>
* <p>
* Note that {@link CollectionStatistics#docCount()} is used instead of
* {@link org.apache.lucene.index.IndexReader#numDocs() IndexReader#numDocs()} because also
* {@link TermStatistics#docFreq()} is used, and when the latter
* is inaccurate, so is {@link CollectionStatistics#docCount()}, and in the same direction.
* In addition, {@link CollectionStatistics#docCount()} does not skew when fields are sparse.
*
* @param collectionStats collection-level statistics
* @param termStats term-level statistics for the term
* @return an Explain object that includes both an idf score factor
* and an explanation for the term.
*/
public Explanation idfExplain(CollectionStatistics collectionStats, TermStatistics termStats) {
final long df = termStats.docFreq();
final long docCount = collectionStats.docCount();
final float idf = idf(df, docCount);
return Explanation.match(idf, "idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:",
Explanation.match(df, "n, number of documents containing term"),
Explanation.match(docCount, "N, total number of documents with field"));
}
/**
* Computes a score factor for a phrase.
*
* <p>
* The default implementation sums the idf factor for
* each term in the phrase.
*
* @param collectionStats collection-level statistics
* @param termStats term-level statistics for the terms in the phrase
* @return an Explain object that includes both an idf
* score factor for the phrase and an explanation
* for each term.
*/
public Explanation idfExplain(CollectionStatistics collectionStats, TermStatistics termStats[]) {
double idf = 0d; // sum into a double before casting into a float
List<Explanation> details = new ArrayList<>();
for (final TermStatistics stat : termStats) {
Explanation idfExplain = idfExplain(collectionStats, stat);
details.add(idfExplain);
idf += idfExplain.getValue().floatValue();
}
return Explanation.match((float) idf, "idf, sum of:", details);
}
@Override
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
Explanation idf = termStats.length == 1 ? idfExplain(collectionStats, termStats[0]) : idfExplain(collectionStats, termStats);
float avgdl = avgFieldLength(collectionStats);
return new CustomScorer(expression, params, boost, idf, avgdl);
}
private static class CustomScorer extends SimScorer {
private final String expression;
private final Map<String, Float> params;
private final float boost;
private final Explanation idf;
private final float avgdl;
public CustomScorer(String expression, Map<String, Float> params, float boost, Explanation idf, float avgdl) {
this.expression = expression;
// score() mutates the parameter map, so work on a copy and leave the caller's map untouched.
this.params = new HashMap<>(params);
this.boost = boost;
this.idf = idf;
this.avgdl = avgdl;
}
@Override
public float score(float freq, long encodedNorm) {
final float idf = this.idf.getValue().floatValue();
final float dl = LENGTH_TABLE[((byte) encodedNorm) & 0xFF];
params.put("idf", idf);
params.put("boost", boost);
params.put("tf", freq);
params.put("dl", dl);
params.put("avgdl", avgdl);
return Exp4jUtil.eval(expression, params);
}
}
}
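As the code shows, every value needed at scoring time is readily available: idf, boost, tf (the freq argument), dl (decoded from the norm), and avgdl.
The Exp4jUtil at the end is our own wrapper for evaluating expressions, built on the exp4j project: https://github.com/fasseg/exp4j.
There are many similar libraries, and exp4j is certainly not the best of them; an engine that precompiles expressions and generates Java bytecode dynamically should perform better. I used exp4j because other projects in the company were already using it.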
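One way to reduce per-document overhead without bytecode generation is to parse the expression once into a tree of closures and then evaluate that tree for each document. The sketch below is neither exp4j nor the author's Exp4jUtil; it is a deliberately small recursive-descent parser with hypothetical names that supports only +, -, *, /, parentheses, numbers, and variables.

```java
import java.util.HashMap;
import java.util.Map;

/** Parses an arithmetic expression once into nested closures, then evaluates it many times. */
final class CompiledExpression {

    /** A compiled sub-expression that reads variable values from the given map. */
    interface Node {
        double eval(Map<String, Double> vars);
    }

    private final String src;
    private int pos;

    private CompiledExpression(String src) {
        this.src = src.replace(" ", ""); // whitespace is simply stripped in this sketch
    }

    static Node compile(String expression) {
        CompiledExpression parser = new CompiledExpression(expression);
        Node root = parser.sum();
        if (parser.pos != parser.src.length()) {
            throw new IllegalArgumentException("Unexpected input at position " + parser.pos);
        }
        return root;
    }

    // sum := product (('+' | '-') product)*
    private Node sum() {
        Node left = product();
        while (pos < src.length() && (src.charAt(pos) == '+' || src.charAt(pos) == '-')) {
            char op = src.charAt(pos++);
            Node l = left, r = product();
            left = op == '+' ? v -> l.eval(v) + r.eval(v) : v -> l.eval(v) - r.eval(v);
        }
        return left;
    }

    // product := atom (('*' | '/') atom)*
    private Node product() {
        Node left = atom();
        while (pos < src.length() && (src.charAt(pos) == '*' || src.charAt(pos) == '/')) {
            char op = src.charAt(pos++);
            Node l = left, r = atom();
            left = op == '*' ? v -> l.eval(v) * r.eval(v) : v -> l.eval(v) / r.eval(v);
        }
        return left;
    }

    // atom := number | identifier | '(' sum ')' | '-' atom
    private Node atom() {
        if (pos >= src.length()) {
            throw new IllegalArgumentException("Unexpected end of expression");
        }
        char c = src.charAt(pos);
        if (c == '(') {
            pos++;
            Node inner = sum();
            if (pos >= src.length() || src.charAt(pos) != ')') {
                throw new IllegalArgumentException("Expected ')' at position " + pos);
            }
            pos++;
            return inner;
        }
        if (c == '-') {
            pos++;
            Node operand = atom();
            return v -> -operand.eval(v);
        }
        int start = pos;
        if (Character.isDigit(c) || c == '.') {
            while (pos < src.length() && (Character.isDigit(src.charAt(pos)) || src.charAt(pos) == '.')) {
                pos++;
            }
            double value = Double.parseDouble(src.substring(start, pos));
            return v -> value;
        }
        if (Character.isLetter(c)) {
            while (pos < src.length() && Character.isLetterOrDigit(src.charAt(pos))) {
                pos++;
            }
            String name = src.substring(start, pos);
            return v -> {
                Double bound = v.get(name);
                if (bound == null) {
                    throw new IllegalArgumentException("Unbound variable: " + name);
                }
                return bound;
            };
        }
        throw new IllegalArgumentException("Unexpected character '" + c + "' at position " + pos);
    }

    public static void main(String[] args) {
        Node bm25 = compile("idf*boost*tf/(tf+k*((1-b)+b*dl/avgdl))"); // parsed once
        Map<String, Double> vars = new HashMap<>();
        vars.put("idf", 2.0); vars.put("boost", 1.0); vars.put("tf", 3.0);
        vars.put("avgdl", 100.0); vars.put("k", 1.2); vars.put("b", 0.75);
        for (double dl : new double[] {50, 100, 400}) { // evaluated per document
            vars.put("dl", dl);
            System.out.println("dl=" + dl + " score=" + bm25.eval(vars));
        }
    }
}
```

In a Similarity, compile would be called once when the scorer is created and eval once per document. Whether this is fast enough for a given workload would still need to be measured.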
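UPDATE: Performance testing showed that exp4j does not meet our performance requirements.
We are trying paralithic as a replacement.
Overall, the implementation is not difficult. The key is to understand what the relevant Lucene classes do and how they relate to one another.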
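Original-content notice: this article is published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.
In case of infringement, contact cloudcommunity@tencent.com for removal.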