From c20838176e8bea9e5a176c59c78bbce9051ec987 Mon Sep 17 00:00:00 2001 From: Naveen Saini Date: Fri, 27 Aug 2021 11:41:47 +0800 Subject: [PATCH 2/2] [X86] When storing v1i1/v2i1/v4i1 to memory, make sure we store zeros in the rest of the byte We can't store garbage in the unused bits. It possible that something like zextload from i1/i2/i4 is created to read the memory. Those zextloads would be legalized assuming the extra bits are 0. I'm not sure that the code in lowerStore is executed for the v1i1/v2i1/v4i1 case. It looks like the DAG combine in combineStore may have converted them to v8i1 first. And I think we're missing some cases to avoid going to the stack in the first place. But I don't have time to investigate those things at the moment so I wanted to focus on the correctness issue. Should fix PR48147. Reviewed By: RKSimon Differential Revision: https://reviews.llvm.org/D91294 Upstream-Status: Backport [https://github.com/llvm/llvm-project/commit/a4124e455e641db1e18d4221d2dacb31953fd13b] Signed-off-by: Naveen Saini --- llvm/lib/Target/X86/X86ISelLowering.cpp | 19 ++++++++++++++----- llvm/lib/Target/X86/X86InstrAVX512.td | 3 --- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 56690c3c555b..7e673a3163b7 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -23549,17 +23549,22 @@ static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, // Without AVX512DQ, we need to use a scalar type for v2i1/v4i1/v8i1 stores. if (StoredVal.getValueType().isVector() && StoredVal.getValueType().getVectorElementType() == MVT::i1) { - assert(StoredVal.getValueType().getVectorNumElements() <= 8 && - "Unexpected VT"); + unsigned NumElts = StoredVal.getValueType().getVectorNumElements(); + assert(NumElts <= 8 && "Unexpected VT"); assert(!St->isTruncatingStore() && "Expected non-truncating store"); assert(Subtarget.hasAVX512() && !Subtarget.hasDQI() && "Expected AVX512F without AVX512DQI"); + // We must pad with zeros to ensure we store zeroes to any unused bits. StoredVal = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1, DAG.getUNDEF(MVT::v16i1), StoredVal, DAG.getIntPtrConstant(0, dl)); StoredVal = DAG.getBitcast(MVT::i16, StoredVal); StoredVal = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, StoredVal); + // Make sure we store zeros in the extra bits. + if (NumElts < 8) + StoredVal = DAG.getZeroExtendInReg(StoredVal, dl, + MVT::getIntegerVT(NumElts)); return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(), St->getPointerInfo(), St->getOriginalAlign(), @@ -44133,17 +44138,21 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, if (VT == MVT::v1i1 && VT == StVT && Subtarget.hasAVX512() && StoredVal.getOpcode() == ISD::SCALAR_TO_VECTOR && StoredVal.getOperand(0).getValueType() == MVT::i8) { - return DAG.getStore(St->getChain(), dl, StoredVal.getOperand(0), + SDValue Val = StoredVal.getOperand(0); + // We must store zeros to the unused bits. + Val = DAG.getZeroExtendInReg(Val, dl, MVT::i1); + return DAG.getStore(St->getChain(), dl, Val, St->getBasePtr(), St->getPointerInfo(), St->getOriginalAlign(), St->getMemOperand()->getFlags()); } // Widen v2i1/v4i1 stores to v8i1. - if ((VT == MVT::v2i1 || VT == MVT::v4i1) && VT == StVT && + if ((VT == MVT::v1i1 || VT == MVT::v2i1 || VT == MVT::v4i1) && VT == StVT && Subtarget.hasAVX512()) { unsigned NumConcats = 8 / VT.getVectorNumElements(); - SmallVector Ops(NumConcats, DAG.getUNDEF(VT)); + // We must store zeros to the unused bits. + SmallVector Ops(NumConcats, DAG.getConstant(0, dl, VT)); Ops[0] = StoredVal; StoredVal = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8i1, Ops); return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(), diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td index a3ad0b1c8dd6..aa1ccec02f2a 100644 --- a/llvm/lib/Target/X86/X86InstrAVX512.td +++ b/llvm/lib/Target/X86/X86InstrAVX512.td @@ -2871,9 +2871,6 @@ def : Pat<(i64 (bitconvert (v64i1 VK64:$src))), // Load/store kreg let Predicates = [HasDQI] in { - def : Pat<(store VK1:$src, addr:$dst), - (KMOVBmk addr:$dst, (COPY_TO_REGCLASS VK1:$src, VK8))>; - def : Pat<(v1i1 (load addr:$src)), (COPY_TO_REGCLASS (KMOVBkm addr:$src), VK1)>; def : Pat<(v2i1 (load addr:$src)), -- 2.17.1