Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 42 additions & 45 deletions src/simplify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,58 +133,55 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)

@assert !isnothing(d)

target_indices = eliminate_indices(vcat(get_free_indices(arg1), get_indices(arg2)))
factors = collect_factors(arg1)
vector_factors = filter(f -> f != d, factors)
other_factors = filter(f -> f != d, factors)
all_factors = vcat(filter(f -> f != d, factors), collect_factors(arg2))
reshaped = []

for f ∈ factors
if isequal(f, d)
continue
end

free_ids = get_free_indices(f)

if isempty(free_ids)
push!(reshaped, f)
elseif length(free_ids) == 1 || length(free_ids) == 2
@assert length(target_indices) == 1

vector_index =
only(get_free_indices(to_binary_operation(Mult(), vector_factors)))

current_idx = intersect(free_ids, [vector_index])

if !isempty(current_idx)
f = update_index(
f,
vector_index,
only(target_indices);
allow_shape_change = true,
)
el1 = eliminated_indices([get_free_indices(d); get_free_indices(arg2)])
ic1 = indices_in_common(d, to_binary_operation(Mult(), other_factors))
el2 = eliminated_indices(
[
get_free_indices(d);
get_free_indices(to_binary_operation(Mult(), other_factors))
],
)
ic2 = indices_in_common(d, arg2)

if !isempty(intersect(el1, ic1)) || !isempty(intersect(el2, ic2))

for f ∈ all_factors
free_ids = get_free_indices(f)

if isempty(free_ids)
push!(reshaped, f)
else
if can_contract(d, f)
push!(reshaped, evaluate(Mult(), d, f))
elseif !isempty(indices_in_common(d, f))
common_index = only(indices_in_common(d, f))
target_index = if first(d.indices) == common_index
last(d.indices)
else
first(d.indices)
end

f = update_index(
f,
common_index,
target_index;
allow_shape_change = true,
)

push!(reshaped, f)
else
push!(reshaped, f)
end
end

push!(reshaped, f)
else
@assert false "Not implemented, please open an issue with your input"
end
end

arg2_ids = get_free_indices(arg2)

if length(arg2_ids) == 1
arg2 = update_index(
arg2,
only(arg2_ids),
only(target_indices);
allow_shape_change = true,
)
push!(reshaped, arg2)
else
@assert false "Not implemented, please open an issue with your input"
return to_binary_operation(Mult(), reshaped)
end

return to_binary_operation(Mult(), reshaped)
end

return op
Expand Down
Loading