Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,18 @@ abstract class AccessBasedStorage<S : AccessBasedStorage<S>> {
}

pattern.forEachAccessor { accessor, accessorPattern ->
children[accessor]?.collectNodesContains(accessorPattern, nodes)
collectNodesContainsAccessor(accessorPattern, accessor, nodes)
}
}

open fun collectNodesContainsAccessor(
pattern: AccessTree.AccessNode,
accessor: Accessor,
nodes: MutableList<S>
) {
children[accessor]?.collectNodesContains(pattern, nodes)
}

fun allNodes(): Sequence<S> {
val storages = mutableListOf<S>()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.seqra.dataflow.ap.ifds.access.tree

import org.seqra.dataflow.ap.ifds.Accessor
import org.seqra.dataflow.ap.ifds.AnyAccessor
import org.seqra.dataflow.ap.ifds.ExclusionSet
import org.seqra.dataflow.ap.ifds.access.common.CommonF2FSummary
import org.seqra.dataflow.ap.ifds.access.common.CommonF2FSummary.F2FBBuilder
Expand Down Expand Up @@ -31,6 +33,19 @@ private class MethodTaintedSummariesInitialApStorage(
}
}

override fun collectNodesContainsAccessor(
pattern: AccessTreeNode,
accessor: Accessor,
nodes: MutableList<MethodTaintedSummariesInitialApStorage>
) {
if (accessor is AnyAccessor) {
nodes += allNodes()
return
}

super.collectNodesContainsAccessor(pattern, accessor, nodes)
}

fun collectAllSummariesTo(dst: MutableList<F2FBBuilder<AccessPath.AccessNode?, AccessTree.AccessNode>>) {
allNodes().forEach { node ->
node.current?.summaries()?.let { dst.add(it) }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package org.seqra.dataflow.ap.ifds.trace

import org.seqra.dataflow.ap.ifds.AccessPathBase
import org.seqra.dataflow.ap.ifds.MethodEntryPoint
import org.seqra.dataflow.ap.ifds.TaintAnalysisUnitRunnerManager
import org.seqra.dataflow.ap.ifds.taint.TaintSinkTracker
Expand All @@ -21,9 +20,23 @@ class TraceResolver(
val resolveEntryPointToStartTrace: Boolean = true,
val startToSourceTraceResolutionLimit: Int? = null,
val startToSinkTraceResolutionLimit: Int? = null,
val sourceToSinkInnerTraceResolutionLimit: Int? = null
val sourceToSinkInnerTraceResolutionLimit: Int? = null,
val innerCallTraceResolveStrategy: InnerCallTraceResolveStrategy = InnerCallTraceResolveStrategy.Default,
)

interface InnerCallTraceResolveStrategy {
fun innerCallTraceIsRelevant(callSummary: TraceEntryAction.CallSummary): Boolean =
callSummary.summaryEdges.any { innerCallSummaryEdgeIsRelevant(it) }

fun innerCallSummaryEdgeIsRelevant(summaryEdge: TraceEntryAction.TraceSummaryEdge): Boolean =
when (summaryEdge) {
is TraceEntryAction.TraceSummaryEdge.SourceSummary -> true
is TraceEntryAction.TraceSummaryEdge.MethodSummary -> summaryEdge.edge.fact != summaryEdge.edgeAfter.fact
}

object Default: InnerCallTraceResolveStrategy
}

data class Trace(
val entryPointToStart: EntryPointToStartTrace?,
val sourceToSinkTrace: SourceToSinkTrace,
Expand Down Expand Up @@ -359,6 +372,12 @@ class TraceResolver(
}

addInnerTraces(fullTrace, innerDepth)

if (kind == CallKind.CallInnerTrace) {
resultNodes += InterProceduralFullTraceNode(fullTrace)
continue
}

when (val start = fullTrace.startEntry) {
is SourceStartEntry -> {
resultNodes += resolveNode(fullTrace, kind, depth)
Expand Down Expand Up @@ -444,15 +463,6 @@ class TraceResolver(
}
}

private fun TraceEntryAction.CallSummary.isRelevantCall(): Boolean = summaryEdges.any {
if (it.edge.fact.base is AccessPathBase.ClassStatic) return@any false

when (it) {
is TraceEntryAction.TraceSummaryEdge.SourceSummary -> true
is TraceEntryAction.TraceSummaryEdge.MethodSummary -> it.edge.fact != it.edgeAfter.fact
}
}

private fun addInnerTraces(trace: MethodTraceResolver.FullTrace, depth: Int) {
if (params.sourceToSinkInnerTraceResolutionLimit != null) {
if (depth > params.sourceToSinkInnerTraceResolutionLimit) {
Expand All @@ -469,7 +479,7 @@ class TraceResolver(

val action = entry.primaryAction
if (action !is TraceEntryAction.CallSummary) continue
if (!action.isRelevantCall()) continue
if (!params.innerCallTraceResolveStrategy.innerCallTraceIsRelevant(action)) continue

val summary = action.summaryTrace
addUnprocessedEvent(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,27 +70,6 @@ class JIRFactTypeChecker(private val cp: JIRClasspath) : FactTypeChecker {
is FieldAccessor -> {
if (actualType !is JIRRefType) return FilterResult.Reject

if (accessor == badElementAccessor) {
if (elementAccessors.any { checkAccessor(it) is FilterResult.Accept }) {
return FilterResult.Accept
}
return FilterResult.Reject
}

if (accessor == badMapKeyAccessor) {
if (mapKeyAccessors.any { checkAccessor(it) is FilterResult.Accept }) {
return FilterResult.Accept
}
return FilterResult.Reject
}

if (accessor == badMapValueAccessor) {
if (mapValueAccessors.any { checkAccessor(it) is FilterResult.Accept }) {
return FilterResult.Accept
}
return FilterResult.Reject
}

val factType = fieldClassType(accessor) ?: return FilterResult.Accept
if (!typeMayHaveSubtypeOf(actualType, factType)) return FilterResult.Reject
return FilterResult.Accept
Expand Down Expand Up @@ -273,33 +252,4 @@ class JIRFactTypeChecker(private val cp: JIRClasspath) : FactTypeChecker {
}
return false
}

// todo: fix config
private val badElementAccessor = FieldAccessor("java.lang.Object", "Element", "java.lang.Object")

private val elementBases = listOf(
"java.lang.Iterable",
"java.util.Iterator",
"java.util.Optional",
)

private val elementAccessors = elementBases.map {
FieldAccessor(it, "Element", "java.lang.Object")
}

private val mapBases = listOf(
"java.util.Map",
"java.util.Map\$Entry",
"org.springframework.http.ResponseEntity"
)

private val badMapKeyAccessor = FieldAccessor("java.lang.Object", "MapKey", "java.lang.Object")
private val mapKeyAccessors = mapBases.map {
FieldAccessor(it, "MapKey", "java.lang.Object")
}

private val badMapValueAccessor = FieldAccessor("java.lang.Object", "MapValue", "java.lang.Object")
private val mapValueAccessors = mapBases.map {
FieldAccessor(it, "MapValue", "java.lang.Object")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class LambdaAnonymousClassFeature : JIRClasspathExtFeature {

val lambdaClass = JIRLambdaClass(
lambdaClassName, declaredFields, declaredMethods,
lambdaMethod.method, lambdaMethod.method.enclosingClass
lambdaMethod.method, lambdaMethod.method.enclosingClass, location
).also {
val locationClass = location.method.enclosingClass
it.bindWithLocation(locationClass.classpath, locationClass.declaration.location)
Expand Down Expand Up @@ -304,7 +304,8 @@ class LambdaAnonymousClassFeature : JIRClasspathExtFeature {
fields: List<JIRVirtualField>,
methods: List<JIRVirtualMethod>,
val lambdaMethod: JIRMethod,
private val lambdaInterfaceType: JIRClassOrInterface
val lambdaInterfaceType: JIRClassOrInterface,
val lambdaLocation: JIRInstLocation
) : JIRVirtualClassImpl(name, initialFields = fields, initialMethods = methods) {

private lateinit var declarationLocation: RegisteredLocation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class JIRMethodCallSummaryHandler(
}

analysisContext.aliasAnalysis?.forEachAliasAfterStatement(statement, summaryFactAp) { aliased ->
handleSummaryEdge(initialFactRefinement, aliased)
result += handleSummaryEdge(initialFactRefinement, aliased)
}

handleSummaryEdge(initialFactRefinement, summaryFactAp)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package org.seqra.dataflow.jvm.ap.ifds.analysis

import org.seqra.dataflow.ap.ifds.Accessor
import org.seqra.dataflow.ap.ifds.AnalysisRunner
import org.seqra.dataflow.ap.ifds.AnyAccessor
import org.seqra.dataflow.ap.ifds.ExclusionSet
import org.seqra.dataflow.ap.ifds.MethodSummaryEdgeApplicationUtils
import org.seqra.dataflow.ap.ifds.SideEffectKind
Expand Down Expand Up @@ -42,7 +44,20 @@ class JIRMethodSideEffectHandler(
val allAccessors = delta.getAllAccessors()
if (mark !in allAccessors) return

val relevantStartAccessors = delta.getStartAccessors().filter { accessor ->
val startAccessors = hashSetOf<Accessor>()
for (accessor in delta.getStartAccessors()) {
if (accessor !is AnyAccessor) {
startAccessors.add(accessor)
continue
}

val anySuccessors = delta.readAccessor(accessor)?.getStartAccessors()
?: continue

anySuccessors.filterTo(startAccessors) { it !is AnyAccessor }
}

val relevantStartAccessors = startAccessors.filter { accessor ->
accessor == mark || delta.readAccessor(accessor)?.getAllAccessors()?.contains(mark) ?: false
}

Expand Down
Loading